use ndarray::{Array1, Array2, ArrayView1};
use std::ops::Range;
use std::sync::Arc;
use crate::cache::Fingerprinter;
use crate::linalg::faer_ndarray::{FaerArrayView, FaerLlt};
use crate::solver::arrow_schur_beta_graph::BetaCouplingGraph;
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, AnalyticPenaltyRegistry, PenaltyTier};
use crate::terms::latent_coord::{LatentCoordValues, LatentManifold};
/// Largest `k` for which the automatic mode selectors on `ArrowSolverMode`
/// pick `Direct` / `SqrtBA`; any larger `k` selects `InexactPCG`.
const DIRECT_SOLVE_MAX_K: usize = 2_000;
/// Default iteration cap for both `ArrowPcgOptions` and
/// `ArrowTrustRegionOptions`.
const DEFAULT_PCG_MAX_ITERATIONS: usize = 200;
/// Default `relative_tolerance` of `ArrowPcgOptions`; also the default
/// `steihaug_relative_tolerance` of `ArrowTrustRegionOptions`.
const DEFAULT_PCG_RELATIVE_TOLERANCE: f64 = 1e-4;
/// Default trust-region radius (infinite).
const DEFAULT_TRUST_REGION_RADIUS: f64 = f64::INFINITY;
/// Default `initial_ridge` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_INITIAL_RIDGE: f64 = 1e-8;
/// Default `ridge_growth` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_RIDGE_GROWTH: f64 = 10.0;
/// Default `max_attempts` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_MAX_ATTEMPTS: usize = 16;
/// Default `armijo_c1` of `ArrowProximalCorrectionOptions`.
const DEFAULT_ARMIJO_C1: f64 = 1e-4;
/// Default `gradient_tolerance` of `ArrowProximalCorrectionOptions`.
const DEFAULT_GRADIENT_TOLERANCE: f64 = 1e-10;
/// Manifold-mode fingerprint reported for Euclidean latent geometry; also the
/// initial value stored by the `ArrowSchurSystem` constructors.
const EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT: u64 = 0;
/// 256 MiB.
/// NOTE(review): not referenced in this part of the file; presumably a byte
/// budget for cached H_t-beta data in the factor cache — confirm at the use site.
const ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Shared, thread-safe callback `(x, out)` acting on the beta coordinates.
///
/// `MatvecDiagPenaltyOp` invokes it with a zero-initialised `out` buffer and
/// then adds that buffer into its own output.
pub type SharedBetaMatvec =
Arc<dyn for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Shared, thread-safe callback `(index, x, out)`.
/// NOTE(review): the `usize` is presumably a row index — confirm against callers.
pub type RowHtbetaMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Shared, thread-safe callback that builds one `ArrowRowBlock` (or fails with
/// `ArrowSchurError`) for a `usize` argument, presumably the row index.
pub type StreamingArrowRowBuilder =
Arc<dyn Fn(usize) -> Result<ArrowRowBlock, ArrowSchurError> + Send + Sync>;
/// Shared, thread-safe callback `(x, out)` on owned arrays; stored in
/// `ArrowSolveOptions::gpu_matvec`.
pub type GpuSchurMatvec = Arc<dyn Fn(&Array1<f64>, &mut Array1<f64>) + Send + Sync>;
/// Unsized slice of `f64` metric weights.
type MetricWeights = [f64];
/// Index of a beta block: `BetaBlockId(i)` selects `offsets[i]` in
/// `BetaPenaltyOp::block`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
/// Abstract operator over the `k` beta coordinates.
///
/// Every implementation in this file *accumulates* into the output of
/// `matvec`, `gradient`, `diagonal` and `block` (using `+=`); callers are
/// expected to zero the buffer first when they want the bare result.
pub trait BetaPenaltyOp: Send + Sync {
/// Number of beta coordinates the operator acts on.
fn dim(&self) -> usize;
/// Adds the operator applied to `x` into `y`.
fn matvec(&self, x: &[f64], y: &mut [f64]);
/// Adds the gradient contribution at `beta` into `out`. The implementations
/// in this file compute the same product as `matvec`.
fn gradient(&self, beta: &[f64], out: &mut [f64]);
/// Adds the operator's diagonal into `diag`.
fn diagonal(&self, diag: &mut [f64]);
/// Adds the square sub-block covering `offsets[id.0]` into `out`, which is
/// indexed relative to the start of that range.
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
/// Materialises the operator as a dense matrix.
fn to_dense(&self) -> Array2<f64>;
/// Feeds the operator's identifying data into `hasher`.
fn fingerprint(&self, hasher: &mut Fingerprinter);
}
/// Penalty operator that stores its matrix explicitly.
///
/// Every method sizes its loops from `nrows()`, i.e. the matrix is treated as
/// square.
pub struct DensePenaltyOp(pub Array2<f64>);

impl BetaPenaltyOp for DensePenaltyOp {
    fn dim(&self) -> usize {
        self.0.nrows()
    }

    /// Adds the matrix-vector product into `y` (`y += H x`).
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let matrix = &self.0;
        let n = matrix.nrows();
        for row in 0..n {
            let dot = (0..n).fold(0.0_f64, |sum, col| sum + matrix[[row, col]] * x[col]);
            y[row] += dot;
        }
    }

    /// Adds `H beta` into `out`; this is the same product `matvec` computes.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds the leading diagonal entries into `diag`, stopping at whichever of
    /// the matrix size or `diag.len()` is smaller.
    fn diagonal(&self, diag: &mut [f64]) {
        let limit = self.0.nrows().min(diag.len());
        for (j, slot) in diag.iter_mut().enumerate().take(limit) {
            *slot += self.0[[j, j]];
        }
    }

    /// Adds the diagonal sub-block spanning `offsets[id.0]` into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        let start = span.start;
        for bi in 0..width {
            let gi = start + bi;
            for bj in 0..width {
                out[[bi, bj]] += self.0[[gi, start + bj]];
            }
        }
    }

    fn to_dense(&self) -> Array2<f64> {
        self.0.clone()
    }

    /// Hashes a version tag followed by the matrix shape and entries.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("dense-penalty-op-v1");
        write_array2_fingerprint(hasher, &self.0);
    }
}
/// Block-structured penalty: a list of `(offset, local)` pairs, each placing a
/// square `local` matrix on the diagonal of a `k x k` operator starting at
/// global index `offset`. Contributions of overlapping blocks add up.
pub struct BlockPenaltyOp {
    pub k: usize,
    pub blocks: Vec<(usize, Array2<f64>)>,
}

impl BetaPenaltyOp for BlockPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds every block's local product into the matching slice of `y`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for &(start, ref local) in &self.blocks {
            let size = local.nrows();
            for i in 0..size {
                let gi = start + i;
                let dot = (0..size).fold(0.0_f64, |sum, j| sum + local[[i, j]] * x[start + j]);
                y[gi] += dot;
            }
        }
    }

    /// Same product as `matvec`, accumulated into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds each block's own diagonal at its global position.
    fn diagonal(&self, diag: &mut [f64]) {
        for &(start, ref local) in &self.blocks {
            for j in 0..local.nrows() {
                diag[start + j] += local[[j, j]];
            }
        }
    }

    /// Adds the part of every stored block that falls inside `offsets[id.0]`
    /// into `out`, which is indexed relative to the start of that range.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        for &(start, ref local) in &self.blocks {
            let stop = start + local.nrows();
            let overlaps = stop > span.start && start < span.end;
            if !overlaps {
                continue;
            }
            // Global index -> block-local index, when the index lies in this block.
            let to_local = |global: usize| {
                if global >= start && global < stop {
                    Some(global - start)
                } else {
                    None
                }
            };
            for bi in 0..width {
                if let Some(li) = to_local(span.start + bi) {
                    for bj in 0..width {
                        if let Some(lj) = to_local(span.start + bj) {
                            out[[bi, bj]] += local[[li, lj]];
                        }
                    }
                }
            }
        }
    }

    /// Builds the full `k x k` matrix by summing every block into place.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for &(start, ref local) in &self.blocks {
            let size = local.nrows();
            for i in 0..size {
                for j in 0..size {
                    dense[[start + i, start + j]] += local[[i, j]];
                }
            }
        }
        dense
    }

    /// Hashes a version tag, `k`, the block count, then each offset and block.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("block-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for &(start, ref local) in &self.blocks {
            hasher.write_usize(start);
            write_array2_fingerprint(hasher, local);
        }
    }
}
/// Penalty with Kronecker structure `factor_a (x) factor_b`, embedded in a
/// `k x k` operator starting at `global_offset`.
///
/// The local index is `i_a * p_b + i_b`, where `p_b = factor_b.nrows()`.
pub struct KroneckerPenaltyOp {
    pub factor_a: Array2<f64>,
    pub factor_b: Array2<f64>,
    pub global_offset: usize,
    pub k: usize,
}

impl BetaPenaltyOp for KroneckerPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds `(A (x) B) x` into `y` without forming the product matrix.
    /// Exactly-zero entries of `factor_a` are skipped.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        for local_row in 0..rows_a * rows_b {
            let (i_a, i_b) = (local_row / rows_b, local_row % rows_b);
            let gi = base + local_row;
            let mut acc = 0.0_f64;
            for j_a in 0..rows_a {
                let a_ij = self.factor_a[[i_a, j_a]];
                if a_ij != 0.0 {
                    for j_b in 0..rows_b {
                        acc += a_ij * self.factor_b[[i_b, j_b]] * x[base + j_a * rows_b + j_b];
                    }
                }
            }
            y[gi] += acc;
        }
    }

    /// Same product as `matvec`, accumulated into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds `A[i_a, i_a] * B[i_b, i_b]` at every covered global index.
    fn diagonal(&self, diag: &mut [f64]) {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        for local in 0..rows_a * rows_b {
            let (i_a, i_b) = (local / rows_b, local % rows_b);
            diag[base + local] += self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
        }
    }

    /// Adds the entries of the Kronecker product that fall inside
    /// `offsets[id.0]` into `out` (indexed relative to the range start).
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        let stop = base + rows_a * rows_b;
        let overlaps = stop > span.start && base < span.end;
        if !overlaps {
            return;
        }
        let covered = |global: usize| global >= base && global < stop;
        for bi in 0..width {
            let gi = span.start + bi;
            if !covered(gi) {
                continue;
            }
            let li = gi - base;
            let (i_a, i_b) = (li / rows_b, li % rows_b);
            for bj in 0..width {
                let gj = span.start + bj;
                if covered(gj) {
                    let lj = gj - base;
                    let (j_a, j_b) = (lj / rows_b, lj % rows_b);
                    out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
                }
            }
        }
    }

    /// Materialises the embedded Kronecker product as a `k x k` matrix.
    fn to_dense(&self) -> Array2<f64> {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for local_row in 0..rows_a * rows_b {
            let (i_a, i_b) = (local_row / rows_b, local_row % rows_b);
            let gi = base + local_row;
            for j_a in 0..rows_a {
                let a_ij = self.factor_a[[i_a, j_a]];
                if a_ij != 0.0 {
                    for j_b in 0..rows_b {
                        let gj = base + j_a * rows_b + j_b;
                        dense[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
                    }
                }
            }
        }
        dense
    }

    /// Hashes a version tag, the offset, `k`, then both factors.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("kronecker-penalty-op-v1");
        hasher.write_usize(self.global_offset);
        hasher.write_usize(self.k);
        write_array2_fingerprint(hasher, &self.factor_a);
        write_array2_fingerprint(hasher, &self.factor_b);
    }
}
/// Sum of several penalty operators that all act on the same `k` coordinates.
pub struct CompositePenaltyOp {
    pub k: usize,
    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}

impl BetaPenaltyOp for CompositePenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Lets every child accumulate its product into `y`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        self.ops.iter().for_each(|child| child.matvec(x, y));
    }

    /// Lets every child accumulate its gradient into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.ops.iter().for_each(|child| child.gradient(beta, out));
    }

    /// Lets every child accumulate its diagonal into `diag`.
    fn diagonal(&self, diag: &mut [f64]) {
        self.ops.iter().for_each(|child| child.diagonal(diag));
    }

    /// Lets every child accumulate its sub-block into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        self.ops.iter().for_each(|child| child.block(id, offsets, out));
    }

    /// Sums the dense forms of all children, starting from a `k x k` zero matrix.
    fn to_dense(&self) -> Array2<f64> {
        let zero = Array2::<f64>::zeros((self.k, self.k));
        self.ops.iter().fold(zero, |mut total, child| {
            let dense = child.to_dense();
            total += &dense;
            total
        })
    }

    /// Hashes a version tag, `k`, the child count, then each child in order.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("composite-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.ops.len());
        self.ops.iter().for_each(|child| child.fingerprint(hasher));
    }
}
/// Matrix-free penalty operator: a matvec callback plus an explicitly supplied
/// diagonal of length `k`.
pub struct MatvecDiagPenaltyOp {
k: usize,
matvec: SharedBetaMatvec,
diagonal_vec: Array1<f64>,
}
impl MatvecDiagPenaltyOp {
/// Wraps `matvec` and `diagonal_vec` as an operator of dimension `k`.
///
/// # Panics
/// Panics if `diagonal_vec.len() != k`.
pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
assert_eq!(diagonal_vec.len(), k);
Self {
k,
matvec,
diagonal_vec,
}
}
}
impl BetaPenaltyOp for MatvecDiagPenaltyOp {
fn dim(&self) -> usize {
self.k
}
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let x_arr = Array1::from_iter(x.iter().copied());
let mut out = Array1::<f64>::zeros(self.k);
(self.matvec)(x_arr.view(), &mut out);
for a in 0..self.k {
y[a] += out[a];
}
}
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
let beta_arr = Array1::from_iter(beta.iter().copied());
let mut hb = Array1::<f64>::zeros(self.k);
(self.matvec)(beta_arr.view(), &mut hb);
for a in 0..self.k {
out[a] += hb[a];
}
}
fn diagonal(&self, diag: &mut [f64]) {
for j in 0..self.k.min(diag.len()) {
diag[j] += self.diagonal_vec[j];
}
}
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let mut probe = Array1::<f64>::zeros(self.k);
for bj in 0..b {
probe.fill(0.0);
probe[range.start + bj] = 1.0;
let mut col = Array1::<f64>::zeros(self.k);
(self.matvec)(probe.view(), &mut col);
for bi in 0..b {
out[[bi, bj]] += col[range.start + bi];
}
}
}
fn to_dense(&self) -> Array2<f64> {
let k = self.k;
let mut out = Array2::<f64>::zeros((k, k));
let mut probe = Array1::<f64>::zeros(k);
for j in 0..k {
probe.fill(0.0);
probe[j] = 1.0;
let mut col = Array1::<f64>::zeros(k);
(self.matvec)(probe.view(), &mut col);
for i in 0..k {
out[[i, j]] = col[i];
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("matvec-diag-penalty-op-v1");
hasher.write_usize(self.k);
for &value in self.diagonal_vec.iter() {
hasher.write_f64(value);
}
}
}
/// Solver strategy selector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    Direct,
    SqrtBA,
    InexactPCG,
}

impl ArrowSolverMode {
    /// Picks `InexactPCG` when `k` exceeds `DIRECT_SOLVE_MAX_K`, otherwise
    /// `Direct`.
    pub const fn automatic(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::Direct
    }

    /// Same threshold as [`automatic`](Self::automatic), but selects `SqrtBA`
    /// instead of `Direct` below it.
    pub const fn automatic_for_single_precision(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::SqrtBA
    }
}
/// Reason recorded in `PcgDiagnostics::stopping_reason`; `Converged` is the
/// `Default` value.
/// NOTE(review): the loop that assigns these variants is outside this view, so
/// the exact condition behind each one should be confirmed there.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
#[default]
Converged,
MaxIter,
TrustRegion,
Indefinite,
Stagnation,
}
/// Counters and final status reported for a PCG solve. `Default` gives all
/// counters zero and `PcgStopReason::Converged`.
/// NOTE(review): field meanings are taken from their names; the code that
/// fills them is outside this view.
#[derive(Debug, Default, Clone, Copy)]
pub struct PcgDiagnostics {
pub iterations: usize,
pub matvec_calls: usize,
pub precond_apply_calls: usize,
pub ridge_escalations: usize,
pub final_relative_residual: f64,
pub stopping_reason: PcgStopReason,
}
/// Iteration cap and relative tolerance for the PCG path.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
pub max_iterations: usize,
pub relative_tolerance: f64,
}
impl Default for ArrowPcgOptions {
/// Uses `DEFAULT_PCG_MAX_ITERATIONS` and `DEFAULT_PCG_RELATIVE_TOLERANCE`.
fn default() -> Self {
Self {
max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
}
}
}
/// Trust-region settings: radius, Steihaug relative tolerance, iteration cap.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
pub radius: f64,
pub steihaug_relative_tolerance: f64,
pub max_iterations: usize,
}
impl Default for ArrowTrustRegionOptions {
/// Infinite radius; tolerance and iteration cap reuse the PCG defaults.
fn default() -> Self {
Self {
radius: DEFAULT_TRUST_REGION_RADIUS,
steihaug_relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
}
}
}
/// Top-level solve configuration.
///
/// `Debug` is implemented by hand because `gpu_matvec` holds a closure; the
/// manual impl prints only whether it is set.
#[derive(Clone)]
pub struct ArrowSolveOptions {
pub mode: ArrowSolverMode,
pub pcg: ArrowPcgOptions,
pub trust_region: ArrowTrustRegionOptions,
// `with_streaming_chunk_size` stores `None` for a chunk size of zero.
pub streaming_chunk_size: Option<usize>,
pub riemannian_trust_region: bool,
pub gpu_matvec: Option<GpuSchurMatvec>,
}
impl std::fmt::Debug for ArrowSolveOptions {
    /// Formats every field; the `gpu_matvec` closure is shown only as a
    /// boolean saying whether one is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_gpu_matvec = self.gpu_matvec.is_some();
        let mut builder = f.debug_struct("ArrowSolveOptions");
        builder.field("mode", &self.mode);
        builder.field("pcg", &self.pcg);
        builder.field("trust_region", &self.trust_region);
        builder.field("streaming_chunk_size", &self.streaming_chunk_size);
        builder.field("riemannian_trust_region", &self.riemannian_trust_region);
        builder.field("gpu_matvec", &has_gpu_matvec);
        builder.finish()
    }
}
/// Settings for the proximal correction: ridge schedule (`initial_ridge`,
/// `ridge_growth`, `max_attempts`), the Armijo constant and a gradient
/// tolerance.
/// NOTE(review): the routine consuming these options is outside this view.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
pub initial_ridge: f64,
pub ridge_growth: f64,
pub max_attempts: usize,
pub armijo_c1: f64,
pub gradient_tolerance: f64,
}
impl Default for ArrowProximalCorrectionOptions {
/// Populates every field from the matching `DEFAULT_*` constant.
fn default() -> Self {
Self {
initial_ridge: DEFAULT_PROXIMAL_INITIAL_RIDGE,
ridge_growth: DEFAULT_PROXIMAL_RIDGE_GROWTH,
max_attempts: DEFAULT_PROXIMAL_MAX_ATTEMPTS,
armijo_c1: DEFAULT_ARMIJO_C1,
gradient_tolerance: DEFAULT_GRADIENT_TOLERANCE,
}
}
}
/// Record of an accepted proximal step: the latent and beta updates, the
/// ridge values in effect, objective values and the attempt count.
/// NOTE(review): the producer of this struct is outside this view; the field
/// semantics below the names should be confirmed there.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
pub delta_t: Array1<f64>,
pub delta_beta: Array1<f64>,
pub ridge_t: f64,
pub ridge_beta: f64,
pub proximal_ridge: f64,
pub objective_value: f64,
pub trial_objective_value: f64,
pub gradient_dot_step: f64,
pub attempts: usize,
}
impl ArrowSolveOptions {
pub fn automatic(k: usize) -> Self {
Self {
mode: ArrowSolverMode::automatic(k),
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
}
}
pub fn direct() -> Self {
Self {
mode: ArrowSolverMode::Direct,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
}
}
pub fn sqrt_ba() -> Self {
Self {
mode: ArrowSolverMode::SqrtBA,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
}
}
pub fn inexact_pcg() -> Self {
Self {
mode: ArrowSolverMode::InexactPCG,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
}
}
pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
self
}
}
/// Backend for the per-row block operations.
///
/// The descriptions below state what `CpuBatchedBlockSolver`, the
/// implementation in this file, does; other backends are presumably expected
/// to match it.
pub trait BatchedBlockSolver {
/// Factors every row's `htt` block with `ridge_t` added to its diagonal.
/// The CPU backend returns one lower Cholesky factor per row and stops at
/// the first row that fails.
fn factor_blocks(
&self,
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
) -> Result<Vec<Array2<f64>>, ArrowSchurError>;
/// Solves with a previously computed factor for a single right-hand side
/// (CPU backend: `chol_solve_vector`).
fn solve_block_vector(&self, factor: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64>;
/// Solves with a previously computed factor for a matrix of right-hand
/// sides (CPU backend: `chol_solve_matrix`).
fn solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64>;
/// Applies only the triangular half of the solve (CPU backend:
/// `lower_triangular_solve_matrix`).
fn sqrt_solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64>;
/// In-place update of `schur`. The CPU backend computes
/// `schur[a, b] -= sum_c left[c, a] * right[c, b]`.
fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
/// CPU implementation of [`BatchedBlockSolver`]; rows are processed one at a
/// time.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;

impl BatchedBlockSolver for CpuBatchedBlockSolver {
    /// Factors each row in order via `factor_one_row`, returning the first
    /// error encountered.
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
    ) -> Result<Vec<Array2<f64>>, ArrowSchurError> {
        rows.iter()
            .enumerate()
            .map(|(row_idx, row)| factor_one_row(row, ridge_t, d, row_idx))
            .collect()
    }

    fn solve_block_vector(&self, factor: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
        chol_solve_vector(factor, rhs)
    }

    fn solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64> {
        chol_solve_matrix(factor, rhs)
    }

    fn sqrt_solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64> {
        lower_triangular_solve_matrix(factor, rhs)
    }

    /// Computes `schur[a, b] -= sum_c left[c, a] * right[c, b]` in place.
    ///
    /// # Panics
    /// Panics when `left`, `right` and `schur` do not all have
    /// `schur.nrows()` columns.
    fn block_gemm_subtract(
        &self,
        schur: &mut Array2<f64>,
        left: &Array2<f64>,
        right: &Array2<f64>,
    ) {
        let k = schur.nrows();
        let d = left.nrows();
        assert_eq!(left.ncols(), k);
        assert_eq!(right.ncols(), k);
        assert_eq!(schur.ncols(), k);
        for c in 0..d {
            for a in 0..k {
                let scale = left[[c, a]];
                // Exactly-zero entries of `left` contribute nothing; skip the row update.
                if scale != 0.0 {
                    for b in 0..k {
                        schur[[a, b]] -= scale * right[[c, b]];
                    }
                }
            }
        }
    }
}
/// Factors one row's `htt + ridge_t * I` block and sanity-checks the result.
///
/// Steps:
/// 1. reject rows whose `htt` is not `d x d` or whose `gt` is not length `d`;
/// 2. add `ridge_t` to the diagonal and take the lower Cholesky factor;
/// 3. estimate the condition number as `(max_diag / min_diag)^2` from the
///    factor's diagonal and reject the row when it exceeds
///    `1 / (sqrt(f64::EPSILON) * max(d, 1))`.
///
/// # Errors
/// * `PerRowFactorFailed` on a shape mismatch or a failed Cholesky.
/// * `PerRowFactorIllConditioned` when the estimate is too large, or (with an
///   infinite estimate) when the factor diagonal holds a non-positive or
///   non-finite entry.
fn factor_one_row(
    row: &ArrowRowBlock,
    ridge_t: f64,
    d: usize,
    row_idx: usize,
) -> Result<Array2<f64>, ArrowSchurError> {
    if row.htt.dim() != (d, d) {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "row {row_idx} H_tt shape {:?} does not match per_point_hessian_block dimension ({d}, {d})",
                row.htt.dim()
            ),
        });
    }
    if row.gt.len() != d {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "row {row_idx} g_t length {} does not match latent dimension {d}",
                row.gt.len()
            ),
        });
    }
    let mut block = row.htt.clone();
    for a in 0..d {
        block[[a, a]] += ridge_t;
    }
    let factor = cholesky_lower(&block).map_err(|e| ArrowSchurError::PerRowFactorFailed {
        row: row_idx,
        reason: format!(
            "row {row_idx} H_tt was non-PD at ridge_t={ridge_t}; \
             cholesky error: {e}"
        ),
    })?;
    let degenerate = || ArrowSchurError::PerRowFactorIllConditioned {
        row: row_idx,
        kappa_estimate: f64::INFINITY,
    };
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for a in 0..d {
        let v = factor[[a, a]];
        // A NaN compares false against everything, so without this guard it
        // would slip past the min/max tracking and the factor would be accepted.
        if !v.is_finite() {
            return Err(degenerate());
        }
        if v < min_diag {
            min_diag = v;
        }
        if v > max_diag {
            max_diag = v;
        }
    }
    if !(min_diag > 0.0 && max_diag.is_finite()) {
        return Err(degenerate());
    }
    let ratio = max_diag / min_diag;
    // The factor's diagonal ratio is squared to estimate the conditioning of
    // the factored matrix itself.
    let kappa_est = ratio * ratio;
    let d_scale = (d as f64).max(1.0);
    let kappa_max = 1.0 / (f64::EPSILON.sqrt() * d_scale);
    if !kappa_est.is_finite() || kappa_est > kappa_max {
        return Err(ArrowSchurError::PerRowFactorIllConditioned {
            row: row_idx,
            kappa_estimate: kappa_est,
        });
    }
    Ok(factor)
}
/// Fingerprint of the latent geometry.
///
/// Euclidean manifolds map to `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT`. Anything
/// else hashes, in order: a version tag, the observation count, the latent
/// dimension, the manifold description, and the flattened metric weights.
fn manifold_mode_fingerprint(latent: &LatentCoordValues) -> u64 {
    let manifold = latent.manifold();
    if !manifold.is_euclidean() {
        let mut hasher = Fingerprinter::new();
        hasher.write_str("arrow-schur-manifold-mode-v1");
        hasher.write_usize(latent.n_obs());
        hasher.write_usize(latent.latent_dim());
        write_latent_manifold(&mut hasher, manifold);

        let mut weights = Vec::new();
        append_latent_metric_weights(&mut weights, manifold);
        hasher.write_usize(weights.len());
        weights.into_iter().for_each(|w| hasher.write_f64(w));
        hasher.finish_u64()
    } else {
        EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT
    }
}
/// Hashes everything in `sys` that determines the row Hessian factorisation.
///
/// Write order: version tag, `n`, `d`, `k`; per row the `htt` block followed
/// by either the address of the row `htbeta` operator (when one is installed)
/// or the dense `htbeta` block; then the penalty operator (or dense `hbb`
/// when none is set); finally the optional `hbb_diag`.
fn row_hessian_fingerprint_for_system(sys: &ArrowSchurSystem) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-row-hessian-v2");
    hasher.write_usize(sys.rows.len());
    hasher.write_usize(sys.d);
    hasher.write_usize(sys.k);

    // The operator is identified by its allocation address, not its contents.
    let htbeta_op_addr: Option<usize> = match sys.htbeta_matvec.as_ref() {
        Some(op) => Some(Arc::as_ptr(op) as *const () as usize),
        None => None,
    };
    for row in &sys.rows {
        write_array2_fingerprint(&mut hasher, &row.htt);
        if let Some(addr) = htbeta_op_addr {
            hasher.write_usize(addr);
        } else {
            write_array2_fingerprint(&mut hasher, &row.htbeta);
        }
    }

    if let Some(op) = sys.penalty_op.as_ref() {
        hasher.write_bool(true);
        op.fingerprint(&mut hasher);
    } else {
        hasher.write_bool(false);
        write_array2_fingerprint(&mut hasher, &sys.hbb);
    }

    if let Some(diag) = sys.hbb_diag.as_ref() {
        hasher.write_bool(true);
        hasher.write_usize(diag.len());
        diag.iter().for_each(|&value| hasher.write_f64(value));
    } else {
        hasher.write_bool(false);
    }
    hasher.finish_u64()
}
/// Merges the row fingerprint with the analytic-penalty registry fingerprint.
///
/// A registry fingerprint of `0` means "no penalties" and leaves `row`
/// untouched; otherwise both values are hashed together under a version tag.
fn combine_row_and_registry_fingerprints(row: u64, registry: u64) -> u64 {
    match registry {
        0 => row,
        _ => {
            let mut hasher = Fingerprinter::new();
            hasher.write_str("arrow-schur-row-hessian-with-penalties-v1");
            hasher.write_u64(row);
            hasher.write_u64(registry);
            hasher.finish_u64()
        }
    }
}
/// Softplus `ln(1 + e^x)` with shortcuts in the tails: returns `x` itself
/// above 30 and `e^x` below -30. NaN falls through to the general formula and
/// yields NaN.
fn stable_softplus_for_fingerprint(x: f64) -> f64 {
    match x {
        large if large > 30.0 => large,
        small if small < -30.0 => small.exp(),
        mid => (1.0 + mid.exp()).ln(),
    }
}
/// Hashes a matrix: row count, column count, then every entry in the array's
/// iteration order.
fn write_array2_fingerprint(hasher: &mut Fingerprinter, values: &Array2<f64>) {
    let (n_rows, n_cols) = values.dim();
    hasher.write_usize(n_rows);
    hasher.write_usize(n_cols);
    values.iter().for_each(|&entry| hasher.write_f64(entry));
}
/// Fingerprints the row-Hessian contribution of one analytic penalty.
///
/// Returns `None` unless the penalty is in the `Psi` tier and is
/// row-block-diagonal. Otherwise hashes a version tag, the penalty name, the
/// lengths of `target_t` and `rho_local`, all `rho_local` values, and then
/// variant-specific data (see the match arms). The order of the writes is part
/// of the fingerprint and must not be changed without bumping the version tag.
fn analytic_penalty_row_hessian_fingerprint(
penalty: &AnalyticPenaltyKind,
target_t: ArrayView1<'_, f64>,
rho_local: ArrayView1<'_, f64>,
) -> Option<u64> {
if penalty.tier() != PenaltyTier::Psi || !analytic_penalty_is_row_block_diagonal(penalty) {
return None;
}
let mut hasher = Fingerprinter::new();
hasher.write_str("arrow-schur-analytic-row-hessian-v1");
hasher.write_str(penalty.name());
hasher.write_usize(target_t.len());
hasher.write_usize(rho_local.len());
for &rho in rho_local.iter() {
hasher.write_f64(rho);
}
match penalty {
AnalyticPenaltyKind::RowPrecisionPrior(p) => {
// Fixed per-row precision: shape, base weight, the effective weight
// `weight * exp(rho)` when the weight is learnable, then every entry.
let (n, rows, cols) = p.lambda_per_row.dim();
hasher.write_str("row-precision-fixed");
hasher.write_usize(n);
hasher.write_usize(rows);
hasher.write_usize(cols);
hasher.write_f64(p.weight);
hasher.write_bool(p.learnable_weight);
if p.learnable_weight {
hasher.write_usize(p.rho_index);
hasher.write_f64(p.weight * rho_local[p.rho_index].exp());
}
for &value in p.lambda_per_row.iter() {
hasher.write_f64(value);
}
}
AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
// `rho_local` is read here with the layout
// [log_alpha | raw_beta | mu (row-major) | optional weight].
let (aux_n, aux_dim) = p.aux.dim();
let (mu_rows, mu_cols) = p.mu.dim();
let weight_offset = p.log_alpha.len() + p.raw_beta.len() + p.mu.len();
hasher.write_str("row-precision-parametric");
hasher.write_usize(aux_n);
hasher.write_usize(aux_dim);
hasher.write_usize(mu_rows);
hasher.write_usize(mu_cols);
hasher.write_f64(p.weight);
hasher.write_bool(p.learnable_weight);
for &value in p.aux.iter() {
hasher.write_f64(value);
}
// For each log_alpha: stored value, value shifted by rho, and its exp.
for k in 0..p.log_alpha.len() {
let active_log_alpha = p.log_alpha[k] + rho_local[k];
hasher.write_f64(p.log_alpha[k]);
hasher.write_f64(active_log_alpha);
hasher.write_f64(active_log_alpha.exp());
}
let raw_beta_offset = p.log_alpha.len();
// For each raw_beta: stored value, shifted value, and its softplus.
for k in 0..p.raw_beta.len() {
let active_raw_beta = p.raw_beta[k] + rho_local[raw_beta_offset + k];
hasher.write_f64(p.raw_beta[k]);
hasher.write_f64(active_raw_beta);
hasher.write_f64(stable_softplus_for_fingerprint(active_raw_beta));
}
let mu_offset = p.log_alpha.len() + p.raw_beta.len();
for k in 0..p.mu.nrows() {
for a in 0..p.mu.ncols() {
// NOTE(review): the row stride is `p.aux.ncols()` while `a` runs over
// `p.mu.ncols()` and `weight_offset` uses `p.mu.len()`; these agree
// only if `mu` has as many columns as `aux` — confirm that invariant.
let idx = mu_offset + k * p.aux.ncols() + a;
hasher.write_f64(p.mu[[k, a]]);
hasher.write_f64(p.mu[[k, a]] + rho_local[idx]);
}
}
if p.learnable_weight {
hasher.write_usize(weight_offset);
hasher.write_f64(p.weight * rho_local[weight_offset].exp());
}
}
_ => {
// Every other row-block-diagonal penalty is identified by its Hessian
// diagonal at `target_t`; a length of 0 is written when none is available.
hasher.write_str("row-block-diagonal");
if let Some(diag) = penalty.hessian_diag(target_t, rho_local) {
hasher.write_usize(diag.len());
for &value in diag.iter() {
hasher.write_f64(value);
}
} else {
hasher.write_usize(0);
}
}
}
Some(hasher.finish_u64())
}
/// Hashes a structural description of `manifold`: a variant tag followed by
/// that variant's parameters. Product variants write their part count and
/// recurse into each part in order; `ProductWithMetric` additionally writes
/// the weight count and each weight.
fn write_latent_manifold(hasher: &mut Fingerprinter, manifold: &LatentManifold) {
match manifold {
LatentManifold::Euclidean => {
hasher.write_str("euclidean");
}
LatentManifold::Circle { period } => {
hasher.write_str("circle");
hasher.write_f64(*period);
}
LatentManifold::Sphere { dim } => {
hasher.write_str("sphere");
hasher.write_usize(*dim);
}
LatentManifold::Interval { lo, hi } => {
hasher.write_str("interval");
hasher.write_f64(*lo);
hasher.write_f64(*hi);
}
LatentManifold::Product(parts) => {
hasher.write_str("product");
hasher.write_usize(parts.len());
for part in parts {
write_latent_manifold(hasher, part);
}
}
LatentManifold::ProductWithMetric { manifolds, weights } => {
hasher.write_str("product-with-metric");
hasher.write_usize(manifolds.len());
for part in manifolds {
write_latent_manifold(hasher, part);
}
hasher.write_usize(weights.len());
for weight in weights {
hasher.write_f64(*weight);
}
}
}
}
/// Appends the metric weights implied by `manifold` to `out`.
///
/// * `Euclidean` contributes a single `1.0`.
/// * `Circle` contributes `1 / period^2`.
/// * `Sphere` contributes `dim` copies of `1 / pi^2`.
/// * `Interval` contributes `1 / (hi - lo)^2`.
/// * `Product` concatenates the weights of its parts.
/// * `ProductWithMetric` contributes its explicit weights and does not
///   recurse into its parts.
fn append_latent_metric_weights(out: &mut Vec<f64>, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => out.push(1.0),
        LatentManifold::Circle { period } => {
            let length = *period;
            out.push(1.0 / (length * length));
        }
        LatentManifold::Sphere { dim } => {
            let scale = std::f64::consts::PI;
            let weight = 1.0 / (scale * scale);
            out.extend(std::iter::repeat(weight).take(*dim));
        }
        LatentManifold::Interval { lo, hi } => {
            let width = *hi - *lo;
            out.push(1.0 / (width * width));
        }
        LatentManifold::Product(parts) => {
            for part in parts {
                append_latent_metric_weights(out, part);
            }
        }
        LatentManifold::ProductWithMetric { weights, .. } => {
            out.extend(weights.iter().copied());
        }
    }
}
/// Per-row data of the arrow system: the `d x d` block `htt`, the `d x k`
/// coupling block `htbeta`, and the length-`d` vector `gt`.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    pub htt: Array2<f64>,
    pub htbeta: Array2<f64>,
    pub gt: Array1<f64>,
}

impl ArrowRowBlock {
    /// Creates an all-zero row block with latent dimension `d` and `k` shared
    /// coefficients.
    pub fn new(d: usize, k: usize) -> Self {
        let htt = Array2::zeros((d, d));
        let htbeta = Array2::zeros((d, k));
        let gt = Array1::zeros(d);
        Self { htt, htbeta, gt }
    }
}
/// Arrow-structured system: one `ArrowRowBlock` per row plus the shared beta
/// block, with cached fingerprints of its contents.
pub struct ArrowSchurSystem {
// One block per row; `n()` is `rows.len()`.
pub rows: Vec<ArrowRowBlock>,
// Dense beta-beta block: `k x k` from `new` / `new_with_per_row_dims`,
// `0 x 0` from `new_matrix_free_shared`. Used when `penalty_op` is `None`.
pub hbb: Array2<f64>,
// Matvec closure installed by the matrix-free constructor / setter.
pub hbb_matvec: Option<SharedBetaMatvec>,
// Optional per-row operator; when set, the fingerprint hashes its address
// instead of each row's dense `htbeta`.
pub htbeta_matvec: Option<RowHtbetaMatvec>,
// Diagonal supplied together with the matrix-free operator.
pub hbb_diag: Option<Array1<f64>>,
// Length-`k` vector, zero-initialised by the constructors.
pub gb: Array1<f64>,
// Latent dimension; the maximum per-row dimension when rows differ.
pub d: usize,
// Per-row latent dimensions.
pub row_dims: Arc<[usize]>,
// Prefix sums of `row_dims`, length `n + 1`.
pub row_offsets: Arc<[usize]>,
// Number of shared beta coordinates.
pub k: usize,
// `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT` until Riemannian geometry is applied.
pub manifold_mode_fingerprint: u64,
// Cached result of `current_row_hessian_fingerprint`, updated by
// `refresh_row_hessian_fingerprint`.
pub row_hessian_fingerprint: u64,
// Set by `add_analytic_penalty_contributions`; 0 means no contribution.
pub analytic_row_hessian_fingerprint: u64,
// Empty until `set_block_offsets` is called.
pub block_offsets: Arc<[Range<usize>]>,
// Structured penalty operator; takes precedence over `hbb` when present.
pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
}
impl ArrowSchurSystem {
pub fn new(n: usize, d: usize, k: usize) -> Self {
let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let mut sys = Self {
rows,
hbb: Array2::<f64>::zeros((k, k)),
hbb_matvec: None,
htbeta_matvec: None,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
pub fn new_matrix_free_shared<F>(
n: usize,
d: usize,
k: usize,
matvec: F,
diag: Array1<f64>,
) -> Self
where
F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
assert_eq!(diag.len(), k);
let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
k,
Arc::clone(&matvec_arc),
diag.clone(),
)));
let mut sys = Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: Some(matvec_arc),
htbeta_matvec: None,
hbb_diag: Some(diag),
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op,
};
sys.refresh_row_hessian_fingerprint();
sys
}
/// Creates a zeroed system whose rows may have different latent dimensions.
///
/// `d` is set to the largest per-row dimension (0 when there are no rows) and
/// `row_offsets` holds the running sum of the dimensions.
pub fn new_with_per_row_dims(per_row_dims: Vec<usize>, k: usize) -> Self {
    let n = per_row_dims.len();
    let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
    let row_dims: Arc<[usize]> = Arc::from(per_row_dims.as_slice());

    // Offsets are the prefix sums 0, d0, d0 + d1, ..., total.
    let mut offsets = Vec::with_capacity(n + 1);
    let mut running = 0usize;
    offsets.push(running);
    for &dim in &per_row_dims {
        running += dim;
        offsets.push(running);
    }
    let row_offsets: Arc<[usize]> = Arc::from(offsets);

    let mut rows = Vec::with_capacity(n);
    for &dim in &per_row_dims {
        rows.push(ArrowRowBlock::new(dim, k));
    }

    let block_offsets: Arc<[Range<usize>]> = Arc::from(Vec::<Range<usize>>::new());
    let mut system = Self {
        rows,
        hbb: Array2::<f64>::zeros((k, k)),
        hbb_matvec: None,
        htbeta_matvec: None,
        hbb_diag: None,
        gb: Array1::<f64>::zeros(k),
        d: max_d,
        row_dims,
        row_offsets,
        k,
        manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
        row_hessian_fingerprint: 0,
        analytic_row_hessian_fingerprint: 0,
        block_offsets,
        penalty_op: None,
    };
    system.refresh_row_hessian_fingerprint();
    system
}
/// Number of rows in the system.
pub fn n(&self) -> usize {
self.rows.len()
}
/// Recomputes the fingerprint of the system's own data (rows, beta block,
/// penalty operator, diagonal); the analytic-penalty fingerprint is not mixed
/// in here.
pub fn compute_row_hessian_fingerprint(&self) -> u64 {
row_hessian_fingerprint_for_system(self)
}
/// Freshly computed fingerprint combined with
/// `analytic_row_hessian_fingerprint`. This recomputes the hash; it does not
/// read the cached `row_hessian_fingerprint` field.
pub fn current_row_hessian_fingerprint(&self) -> u64 {
combine_row_and_registry_fingerprints(
self.compute_row_hessian_fingerprint(),
self.analytic_row_hessian_fingerprint,
)
}
/// Stores the combined fingerprint in `row_hessian_fingerprint`.
pub fn refresh_row_hessian_fingerprint(&mut self) {
self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
}
/// Installs a matrix-free beta operator: `matvec` is stored as `hbb_matvec`
/// and, wrapped with `diag` in a `MatvecDiagPenaltyOp`, as `penalty_op`;
/// `diag` becomes `hbb_diag`. The fingerprint is refreshed afterwards.
///
/// # Panics
/// Panics if `diag.len() != self.k`.
pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
where
    F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    assert_eq!(diag.len(), self.k);
    let shared: SharedBetaMatvec = Arc::new(matvec);
    let wrapped = MatvecDiagPenaltyOp::new(self.k, Arc::clone(&shared), diag.clone());
    self.penalty_op = Some(Arc::new(wrapped));
    self.hbb_matvec = Some(shared);
    self.hbb_diag = Some(diag);
    self.refresh_row_hessian_fingerprint();
}
/// Installs the per-row `htbeta` operator and refreshes the fingerprint
/// (which then hashes the operator's address instead of the dense blocks).
pub fn set_row_htbeta_operator<F>(&mut self, matvec: F)
where
F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
self.htbeta_matvec = Some(Arc::new(matvec));
self.refresh_row_hessian_fingerprint();
}
/// Replaces the beta block offsets. The fingerprint is not refreshed; the
/// offsets are not part of it.
pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
self.block_offsets = offsets;
}
/// Installs a structured penalty operator and refreshes the fingerprint.
pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
self.penalty_op = Some(op);
self.refresh_row_hessian_fingerprint();
}
/// Returns the installed penalty operator, or a `DensePenaltyOp` wrapping a
/// copy of `hbb` when none is installed.
pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
    if let Some(installed) = self.penalty_op.as_ref() {
        return Arc::clone(installed);
    }
    Arc::new(DensePenaltyOp(self.hbb.clone()))
}
/// Adds the beta-block product into `y`: delegates to the penalty operator
/// when one is installed, otherwise multiplies by the dense `hbb`.
#[inline]
fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
    match self.penalty_op.as_ref() {
        Some(op) => op.matvec(x, y),
        None => {
            let hbb = &self.hbb;
            let n = hbb.nrows();
            for row in 0..n {
                let dot = (0..n).fold(0.0_f64, |sum, col| sum + hbb[[row, col]] * x[col]);
                y[row] += dot;
            }
        }
    }
}
/// Adds the beta-block diagonal into `diag`.
///
/// Source priority: penalty operator, then the stored `hbb_diag`, then the
/// diagonal of the dense `hbb`. The dense fallbacks stop at the shorter of
/// the source and `diag`.
#[inline]
fn penalty_diagonal_add(&self, diag: &mut [f64]) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.diagonal(diag);
        return;
    }
    if let Some(stored) = self.hbb_diag.as_ref() {
        diag.iter_mut()
            .zip(stored.iter())
            .for_each(|(slot, &value)| *slot += value);
        return;
    }
    let limit = self.hbb.nrows().min(diag.len());
    for (j, slot) in diag.iter_mut().enumerate().take(limit) {
        *slot += self.hbb[[j, j]];
    }
}
/// Adds the beta sub-block covering `offsets[id.0]` into `out`.
///
/// Uses the penalty operator when installed; otherwise the dense `hbb` if it
/// is `k x k`; otherwise only the stored diagonal, if any.
#[inline]
fn penalty_block_add(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.block(id, offsets, out);
        return;
    }
    let span = &offsets[id.0];
    let width = span.end - span.start;
    let start = span.start;
    if self.hbb.dim() == (self.k, self.k) {
        for bi in 0..width {
            for bj in 0..width {
                out[[bi, bj]] += self.hbb[[start + bi, start + bj]];
            }
        }
        return;
    }
    if let Some(stored) = self.hbb_diag.as_ref() {
        for bi in 0..width {
            out[[bi, bi]] += stored[start + bi];
        }
    }
}
/// Adds the beta sub-matrix restricted to the index set `cols` into `out`
/// (`out[bi, bj] += H[cols[bi], cols[bj]]`).
///
/// With a penalty operator the columns are obtained by unit-vector probes
/// through `matvec`; otherwise the dense `hbb` (if `k x k`) or only the
/// stored diagonal is used.
#[inline]
fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
    let count = cols.len();
    if let Some(op) = self.penalty_op.as_ref() {
        let mut probe = Array1::<f64>::zeros(self.k);
        let mut column = Array1::<f64>::zeros(self.k);
        for (bj, &global_col) in cols.iter().enumerate() {
            probe.fill(0.0);
            probe[global_col] = 1.0;
            // `matvec` accumulates, so the output buffer is cleared first.
            column.fill(0.0);
            {
                let p_slice = probe.as_slice().expect("probe contiguous");
                let r_slice = column.as_slice_mut().expect("result contiguous");
                op.matvec(p_slice, r_slice);
            }
            for (bi, &global_row) in cols.iter().enumerate() {
                out[[bi, bj]] += column[global_row];
            }
        }
        return;
    }
    if self.hbb.dim() == (self.k, self.k) {
        for bi in 0..count {
            for bj in 0..count {
                out[[bi, bj]] += self.hbb[[cols[bi], cols[bj]]];
            }
        }
        return;
    }
    if let Some(stored) = self.hbb_diag.as_ref() {
        for bi in 0..count {
            out[[bi, bi]] += stored[cols[bi]];
        }
    }
}
/// Adds the contributions of every penalty in `registry` to this system and
/// updates the fingerprints.
///
/// Penalties are paired with the registry's rho layout; each one receives its
/// own slice of `rho_global`. `Psi`-tier penalties go through
/// `add_ext_coord_penalty` and contribute to the analytic row-Hessian
/// fingerprint, `Beta`-tier penalties go through `add_beta_penalty`, and
/// `Rho`-tier penalties contribute nothing here.
///
/// All penalties are validated before the system is touched, so a failure
/// leaves the system and its fingerprints exactly as they were.
///
/// # Errors
/// Returns `ArrowSchurError::SchurFactorFailed` if any `Psi`-tier penalty is
/// not row-block-diagonal.
pub fn add_analytic_penalty_contributions(
    &mut self,
    registry: &AnalyticPenaltyRegistry,
    target_t: ArrayView1<'_, f64>,
    target_beta: ArrayView1<'_, f64>,
    rho_global: ArrayView1<'_, f64>,
) -> Result<(), ArrowSchurError> {
    let layout = registry.rho_layout();

    // Pass 1: reject unsupported penalties before mutating anything.
    for (penalty, (_, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
        if matches!(tier, PenaltyTier::Psi) && !analytic_penalty_is_row_block_diagonal(penalty) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "analytic penalty {name:?} couples latent rows; cross-row Hessian contributions are not yet supported on any production solver path. Consider using a row-block-only penalty (ARDPenalty, SparsityPenalty, SoftmaxAssignmentSparsity, IBPAssignment) or filing an issue requesting cross-row Hessian support."
                ),
            });
        }
    }

    // Pass 2: apply the contributions in registry order.
    let mut penalty_fingerprints = Vec::new();
    for (penalty, (rho_slice, tier, _)) in registry.penalties.iter().zip(layout.iter()) {
        let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
        match tier {
            PenaltyTier::Psi => {
                self.add_ext_coord_penalty(penalty, target_t, rho_local);
                if let Some(fingerprint) =
                    analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
                {
                    penalty_fingerprints.push(fingerprint);
                }
            }
            PenaltyTier::Beta => {
                self.add_beta_penalty(penalty, target_beta, rho_local);
            }
            PenaltyTier::Rho => {}
        }
    }

    // 0 is the "no analytic contribution" marker understood by
    // `combine_row_and_registry_fingerprints`.
    self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
        0
    } else {
        let mut hasher = Fingerprinter::new();
        hasher.write_str("arrow-schur-row-hessian-registry-v1");
        hasher.write_usize(penalty_fingerprints.len());
        for fingerprint in penalty_fingerprints {
            hasher.write_u64(fingerprint);
        }
        hasher.finish_u64()
    };
    self.refresh_row_hessian_fingerprint();
    Ok(())
}
/// Re-expresses every per-row block through the latent manifold's geometry.
///
/// The manifold-mode fingerprint is always recorded and the row-Hessian
/// fingerprint is always refreshed. For a Euclidean manifold nothing else
/// happens. Otherwise each row's `gt`, `htt` and `htbeta` are replaced by the
/// results of `project_to_tangent`, `riemannian_hessian_matrix` and
/// `project_matrix_columns_to_tangent`, each evaluated at that row's latent
/// coordinate and fed the values the row held on entry.
///
/// # Panics
/// On a non-Euclidean manifold, panics if `latent` disagrees with the system
/// on the number of rows or on the latent dimension `d`.
pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
    let manifold = latent.manifold();
    self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
    if !manifold.is_euclidean() {
        assert_eq!(latent.n_obs(), self.rows.len());
        assert_eq!(latent.latent_dim(), self.d);
        for (i, row) in self.rows.iter_mut().enumerate() {
            let t_i = ArrayView1::from(latent.row(i));
            // The Hessian transform consumes the pre-projection gradient, so it
            // must run before `gt` is overwritten. Ordering the updates this
            // way lets each call borrow the row in place instead of deep-copying
            // `gt`, `htt` and the (d × K) `htbeta` block for every row.
            row.htt = manifold.riemannian_hessian_matrix(t_i, row.gt.view(), row.htt.view());
            row.gt = manifold.project_to_tangent(t_i, row.gt.view());
            row.htbeta = manifold.project_matrix_columns_to_tangent(t_i, row.htbeta.view());
        }
    }
    self.refresh_row_hessian_fingerprint();
}
/// Applies one latent-coordinate penalty to the per-row `gt` / `htt` blocks.
///
/// The target is treated as a flat vector of `rows.len() * d` entries, where
/// flat index `f` maps to row `f / d`, coordinate `f % d`. Gradient entries go
/// to `gt`, an available Hessian diagonal goes to the diagonal of `htt`, and
/// otherwise one Hessian-vector product per coordinate `a` (probe = 1 at
/// coordinate `a` of every row) fills column `a` of every row's `htt`.
fn add_ext_coord_penalty(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let dim = self.d;
    let n_rows = self.rows.len();
    apply_analytic_penalty(
        penalty,
        target_t,
        rho_local,
        n_rows * dim,
        dim,
        self,
        |sys, flat, value| {
            let (row, coord) = (flat / dim, flat % dim);
            sys.rows[row].gt[coord] += value;
        },
        |sys, flat, value| {
            let (row, coord) = (flat / dim, flat % dim);
            sys.rows[row].htt[[coord, coord]] += value;
        },
        |coord, probe| {
            for row in 0..n_rows {
                probe[row * dim + coord] = 1.0;
            }
        },
        |sys, coord, hv| {
            for row in 0..n_rows {
                let base = row * dim;
                for b in 0..dim {
                    sys.rows[row].htt[[b, coord]] += hv[base + b];
                }
            }
        },
    );
}
/// Applies one β-tier penalty to the shared gradient `gb` and Hessian storage.
///
/// Gradient entries are added to `gb`. A Hessian diagonal, when the penalty
/// provides one, is added to the dense `hbb` (only if it has shape `(k, k)`)
/// and to `hbb_diag` (if present). Without a diagonal, Hessian-vector products
/// against unit vectors fill the dense `hbb` column by column; that path is
/// disabled (zero probe columns) when `hbb` is not `(k, k)`.
fn add_beta_penalty(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let k = self.k;
    let probe_columns = match self.hbb.dim() == (k, k) {
        true => k,
        false => 0,
    };
    apply_analytic_penalty(
        penalty,
        target_beta,
        rho_local,
        k,
        probe_columns,
        self,
        |sys, j, value| sys.gb[j] += value,
        |sys, j, value| {
            if sys.hbb.dim() == (k, k) {
                sys.hbb[[j, j]] += value;
            }
            if let Some(stored) = sys.hbb_diag.as_mut() {
                stored[j] += value;
            }
        },
        |j, probe| probe[j] = 1.0,
        |sys, j, hv| {
            // Column `j` of the penalty Hessian lands in column `j` of `hbb`.
            for i in 0..k {
                sys.hbb[[i, j]] += hv[i];
            }
            if let Some(stored) = sys.hbb_diag.as_mut() {
                stored[j] += hv[j];
            }
        },
    );
}
/// Solves one Newton step with options chosen by
/// `ArrowSolveOptions::automatic(self.k)`.
///
/// Returns `(delta_t, delta_beta, diagnostics)` from
/// `solve_arrow_newton_step_core`.
pub fn solve(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_arrow_newton_step_core(
        self,
        ridge_t,
        ridge_beta,
        &ArrowSolveOptions::automatic(self.k),
    )
}
/// Solves one Newton step with automatic options, retrying with a growing
/// proximal ridge through `solve_with_lm_escalation_inner`.
pub fn solve_with_lm_escalation(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_with_lm_escalation_inner(
        self,
        ridge_t,
        ridge_beta,
        &ArrowSolveOptions::automatic(self.k),
    )
}
/// Solves one Newton step with caller-supplied `options`.
///
/// Thin wrapper over `solve_arrow_newton_step_core`; returns
/// `(delta_t, delta_beta, diagnostics)`.
pub fn solve_with_options(
&self,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
}
}
/// Arrow-Schur solver that obtains row blocks one at a time from a builder
/// callback and accumulates the reduced β system in place.
pub struct StreamingArrowSchur {
/// Number of row blocks; `solve` visits rows `0..n_rows`.
pub n_rows: usize,
/// Row dimension assumed by `validate_row` for rows not covered by `row_dims`.
pub d: usize,
/// Expected dimension of each row block.
pub row_dims: Arc<[usize]>,
/// Start of each row inside the flat `delta_t`; entry `n_rows` is its total length.
pub row_offsets: Arc<[usize]>,
/// Dimension of the shared β block.
pub k: usize,
/// Rows processed per chunk; `new` clamps the supplied value to at least 1.
pub chunk_size: usize,
/// Reduced-system accumulator: reset to `hbb + ridge_beta * I`, then updated per row.
pub s_acc: Array2<f64>,
// Reduced right-hand-side accumulator (length `k`).
rhs_acc: Array1<f64>,
// Dense β-block Hessian supplied at construction (`k × k`).
hbb: Array2<f64>,
// β gradient supplied at construction (length `k`).
gb: Array1<f64>,
// Callback producing the row block for a given row index.
row_builder: StreamingArrowRowBuilder,
}
/// Prints only the scalar shape parameters; the arrays and the row-builder
/// callback are elided (`finish_non_exhaustive`).
impl std::fmt::Debug for StreamingArrowSchur {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            n_rows,
            d,
            k,
            chunk_size,
            ..
        } = self;
        let mut printer = f.debug_struct("StreamingArrowSchur");
        printer.field("n_rows", n_rows);
        printer.field("d", d);
        printer.field("k", k);
        printer.field("chunk_size", chunk_size);
        printer.finish_non_exhaustive()
    }
}
impl StreamingArrowSchur {
/// Builds a streaming solver with zeroed accumulators.
///
/// `chunk_size` is clamped to at least 1.
///
/// # Panics
/// Panics if `hbb` is not `k × k` or `gb` does not have length `k`.
#[must_use]
pub fn new(
    n_rows: usize,
    d: usize,
    row_dims: Arc<[usize]>,
    row_offsets: Arc<[usize]>,
    k: usize,
    hbb: Array2<f64>,
    gb: Array1<f64>,
    row_builder: StreamingArrowRowBuilder,
    chunk_size: usize,
) -> Self {
    assert_eq!(hbb.dim(), (k, k));
    assert_eq!(gb.len(), k);
    let s_acc = Array2::<f64>::zeros((k, k));
    let rhs_acc = Array1::<f64>::zeros(k);
    let chunk_size = chunk_size.max(1);
    Self {
        n_rows,
        d,
        row_dims,
        row_offsets,
        k,
        chunk_size,
        s_acc,
        rhs_acc,
        hbb,
        gb,
        row_builder,
    }
}
/// Wraps an in-memory `ArrowSchurSystem` as a streaming solver.
///
/// The rows are snapshotted into a shared vector that the row builder clones
/// from; when the system carries an `htbeta_matvec`, each row's `htbeta` is
/// first materialised densely. The β Hessian is taken from
/// `effective_penalty_op().to_dense()`.
#[must_use]
pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
    let snapshot: Vec<ArrowRowBlock> = if sys.htbeta_matvec.is_none() {
        sys.rows.clone()
    } else {
        let mut dense_rows = Vec::with_capacity(sys.rows.len());
        for (row_idx, row) in sys.rows.iter().enumerate() {
            let htbeta = sys_htbeta_materialize_row(sys, row_idx, row);
            dense_rows.push(ArrowRowBlock {
                htt: row.htt.clone(),
                htbeta,
                gt: row.gt.clone(),
            });
        }
        dense_rows
    };
    let shared_rows = Arc::new(snapshot);
    let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
        match shared_rows.get(row) {
            Some(block) => Ok(block.clone()),
            None => Err(ArrowSchurError::SchurFactorFailed {
                reason: format!("streaming row {row} out of bounds"),
            }),
        }
    });
    let hbb_dense = sys.effective_penalty_op().to_dense();
    Self::new(
        sys.rows.len(),
        sys.d,
        Arc::clone(&sys.row_dims),
        Arc::clone(&sys.row_offsets),
        sys.k,
        hbb_dense,
        sys.gb.clone(),
        row_builder,
        chunk_size,
    )
}
/// Re-initialises the accumulators for a fresh solve.
///
/// `s_acc` becomes `hbb + ridge_beta * I` and `rhs_acc` is zeroed.
///
/// # Errors
/// `SchurFactorFailed` if the stored `hbb` is not `k × k`.
pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
    let k = self.k;
    if self.hbb.dim() != (k, k) {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
        });
    }
    self.s_acc.assign(&self.hbb);
    for j in 0..k {
        self.s_acc[[j, j]] += ridge_beta;
    }
    for j in 0..k {
        self.rhs_acc[j] = 0.0;
    }
    Ok(())
}
/// Folds rows `start..end` into the reduced-system accumulators.
///
/// For every row the block is built, validated and factorised, its
/// contribution is added to `rhs_acc`, and its product term is subtracted from
/// `s_acc` (via `solve_block_matrix` in `Direct` mode, via
/// `sqrt_solve_block_matrix` in `SqrtBA` mode).
///
/// The solver mode is checked before any row is touched, so an unsupported
/// mode can no longer leave `rhs_acc` partially updated. An empty chunk is a
/// no-op for every mode.
///
/// # Errors
/// * `SchurFactorFailed` if the range is reversed or extends past `n_rows`.
/// * `PcgFailed` if `mode` is `InexactPCG` and the chunk is non-empty.
/// * Any error from the row builder, `validate_row` or `factor_one_row`; rows
///   processed before the failing one stay accumulated.
pub fn accumulate_chunk(
    &mut self,
    start: usize,
    end: usize,
    ridge_t: f64,
    mode: ArrowSolverMode,
) -> Result<(), ArrowSchurError> {
    if start > end || end > self.n_rows {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: format!(
                "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
                self.n_rows
            ),
        });
    }
    // Decide the per-row update up front so an unsupported mode fails before
    // any row is built, factorised or accumulated.
    let whiten = match mode {
        ArrowSolverMode::Direct => false,
        ArrowSolverMode::SqrtBA => true,
        ArrowSolverMode::InexactPCG => {
            if start == end {
                return Ok(());
            }
            return Err(ArrowSchurError::PcgFailed {
                reason: "streaming Arrow-Schur accumulator is for dense direct modes; use matrix-free PCG without streaming_chunk_size".to_string(),
            });
        }
    };
    let backend = CpuBatchedBlockSolver;
    for row_idx in start..end {
        let row = (self.row_builder)(row_idx)?;
        let di = row.htt.nrows();
        self.validate_row(row_idx, &row)?;
        let factor = factor_one_row(&row, ridge_t, di, row_idx)?;
        let v = backend.solve_block_vector(&factor, &row.gt);
        for c in 0..di {
            let vc = v[c];
            if vc == 0.0 {
                continue;
            }
            for a in 0..self.k {
                self.rhs_acc[a] += row.htbeta[[c, a]] * vc;
            }
        }
        if whiten {
            let whitened = backend.sqrt_solve_block_matrix(&factor, &row.htbeta);
            backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
        } else {
            let solved = backend.solve_block_matrix(&factor, &row.htbeta);
            backend.block_gemm_subtract(&mut self.s_acc, &row.htbeta, &solved);
        }
    }
    Ok(())
}
pub fn solve(
&mut self,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
self.reset_accumulator(ridge_beta)?;
for start in (0..self.n_rows).step_by(self.chunk_size) {
let end = (start + self.chunk_size).min(self.n_rows);
self.accumulate_chunk(start, end, ridge_t, options.mode)?;
}
for j in 0..self.k {
self.rhs_acc[j] -= self.gb[j];
}
symmetrize_upper_from_lower(&mut self.s_acc);
let trust_metric_weights = None;
let (delta_beta, schur_factor, _diag) =
solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
Ok((delta_t, delta_beta, schur_factor))
}
/// Recovers the per-row latent steps once `delta_beta` is known.
///
/// Each row is rebuilt, validated and re-factorised with `ridge_t`; the row
/// right-hand side is `gt + htbeta * delta_beta`, and the negated block solve
/// is written at `row_offsets[row]` in the output.
///
/// The right-hand side is sized to the row's own dimension. A shared length-`d`
/// scratch buffer would hand a wrongly sized vector to the block solve whenever
/// a row's dimension differs from `d`. Rows are visited in plain index order,
/// which is the order the chunked traversal produced, so a zero `chunk_size`
/// cannot trigger the `step_by(0)` panic.
///
/// # Errors
/// Propagates errors from the row builder, `validate_row` and
/// `factor_one_row`.
fn back_substitute(
    &self,
    ridge_t: f64,
    delta_beta: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, ArrowSchurError> {
    let backend = CpuBatchedBlockSolver;
    let total_len = self.row_offsets[self.n_rows];
    let mut delta_t = Array1::<f64>::zeros(total_len);
    for row_idx in 0..self.n_rows {
        let row = (self.row_builder)(row_idx)?;
        let di = row.htt.nrows();
        self.validate_row(row_idx, &row)?;
        let factor = factor_one_row(&row, ridge_t, di, row_idx)?;
        let mut rhs = Array1::<f64>::zeros(di);
        for c in 0..di {
            let mut acc = row.gt[c];
            for a in 0..self.k {
                acc += row.htbeta[[c, a]] * delta_beta[a];
            }
            rhs[c] = acc;
        }
        let dt_i = backend.solve_block_vector(&factor, &rhs);
        let row_base = self.row_offsets[row_idx];
        for c in 0..di {
            delta_t[row_base + c] = -dt_i[c];
        }
    }
    Ok(delta_t)
}
/// Checks that a streamed row block has the shapes this solver expects.
///
/// The expected row dimension is `row_dims[row_idx]`, falling back to `d` for
/// indices beyond `row_dims`. `htt` must be square of that size, `htbeta` must
/// be `(dim, k)` and `gt` must have length `dim`.
///
/// # Errors
/// `PerRowFactorFailed` for a bad `htt` or `gt`, `SchurFactorFailed` for a bad
/// `htbeta`.
fn validate_row(&self, row_idx: usize, row: &ArrowRowBlock) -> Result<(), ArrowSchurError> {
    let expected_di = self.row_dims.get(row_idx).copied().unwrap_or(self.d);
    if row.htt.dim() != (expected_di, expected_di) {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
                row.htt.dim(),
            ),
        });
    }
    if row.htbeta.dim() != (expected_di, self.k) {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: format!(
                "streaming row H_tβ shape {:?} != ({expected_di}, {})",
                row.htbeta.dim(),
                self.k
            ),
        });
    }
    if row.gt.len() != expected_di {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
        });
    }
    Ok(())
}
}
/// Shared driver that scatters one analytic penalty into a target system.
///
/// 1. The penalty gradient at `target` is passed entry by entry to
///    `grad_scatter`.
/// 2. If the penalty exposes a Hessian diagonal, each entry goes to
///    `diag_scatter` and nothing further is done.
/// 3. Otherwise, for each of `hvp_columns` columns a zeroed probe is seeded by
///    `seed_hvp_probe`, the Hessian-vector product is evaluated, and the result
///    is handed to `hvp_column_scatter`.
///
/// # Panics
/// Panics if `target` (or a returned Hessian diagonal) does not have
/// `expected_target_len` entries.
fn apply_analytic_penalty<S, G, D, P, H>(
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    expected_target_len: usize,
    hvp_columns: usize,
    scatter_target: &mut S,
    mut grad_scatter: G,
    mut diag_scatter: D,
    seed_hvp_probe: P,
    mut hvp_column_scatter: H,
) where
    G: FnMut(&mut S, usize, f64),
    D: FnMut(&mut S, usize, f64),
    P: Fn(usize, &mut Array1<f64>),
    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
{
    assert_eq!(target.len(), expected_target_len);
    let gradient = penalty.grad_target(target, rho_local);
    for index in 0..expected_target_len {
        grad_scatter(scatter_target, index, gradient[index]);
    }
    match penalty.hessian_diag(target, rho_local) {
        Some(diagonal) => {
            assert_eq!(diagonal.len(), expected_target_len);
            for index in 0..expected_target_len {
                diag_scatter(scatter_target, index, diagonal[index]);
            }
        }
        None => {
            let mut probe = Array1::<f64>::zeros(expected_target_len);
            for column in 0..hvp_columns {
                probe.fill(0.0);
                seed_hvp_probe(column, &mut probe);
                let product = penalty.hvp(target, rho_local, probe.view());
                hvp_column_scatter(scatter_target, column, product.view());
            }
        }
    }
}
/// Forwards to `AnalyticPenaltyKind::is_row_block_diagonal`.
fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
penalty.is_row_block_diagonal()
}
/// Where the undamped (`ridge_t = 0`) per-row factors of a cache live.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
/// The damped factors double as the undamped ones (used when `ridge_t == 0.0`).
SameAsDamped,
/// Separately stored undamped factors, one per row.
Owned(Arc<[Array2<f64>]>),
}
/// Compact debug form: the variant name, plus the factor count for `Owned`.
impl std::fmt::Debug for ArrowUndampedFactors {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::Owned(factors) = self else {
            return f.write_str("SameAsDamped");
        };
        let count = factors.len();
        f.debug_tuple("Owned").field(&count).finish()
    }
}
/// Computes `out = H_tβ[row] * x` for one row.
///
/// Uses the system's `htbeta_matvec` when installed; otherwise multiplies the
/// row's dense `htbeta` against the first `sys.k` entries of `x`, overwriting
/// the leading `htbeta.nrows()` entries of `out`.
fn sys_htbeta_apply_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    x: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    match sys.htbeta_matvec.as_ref() {
        Some(op) => op(row_idx, x, out),
        None => {
            let n_cols = sys.k;
            for c in 0..row.htbeta.nrows() {
                out[c] = (0..n_cols).fold(0.0_f64, |acc, a| acc + row.htbeta[[c, a]] * x[a]);
            }
        }
    }
}
/// Accumulates `out += H_tβ[row]ᵀ * v` for one row.
///
/// With an installed `htbeta_matvec` the transpose is obtained by probing the
/// operator column by column; otherwise the dense `htbeta` is used, skipping
/// entries of `v` that are exactly zero.
fn sys_htbeta_accumulate_transpose(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    if let Some(op) = sys.htbeta_matvec.as_ref() {
        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
        return;
    }
    let n_cols = sys.k;
    for c in 0..row.htbeta.nrows() {
        let weight = v[c];
        if weight != 0.0 {
            for a in 0..n_cols {
                out[a] += row.htbeta[[c, a]] * weight;
            }
        }
    }
}
/// Returns a dense `(row_dims[row_idx], k)` copy of one row's `H_tβ` block.
///
/// A stored `htbeta` of exactly that shape is cloned. Otherwise the block is
/// rebuilt one column at a time by applying `htbeta_matvec` to unit vectors.
///
/// # Panics
/// Panics if the stored block has the wrong shape and no `htbeta_matvec` is
/// installed.
fn sys_htbeta_materialize_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
) -> Array2<f64> {
    let di = sys.row_dims[row_idx];
    let k = sys.k;
    if row.htbeta.dim() == (di, k) {
        return row.htbeta.clone();
    }
    let Some(op) = sys.htbeta_matvec.as_ref() else {
        panic!(
            "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
            row.htbeta.dim()
        )
    };
    let mut dense = Array2::<f64>::zeros((di, k));
    let mut basis = Array1::<f64>::zeros(k);
    let mut image = Array1::<f64>::zeros(di);
    for a in 0..k {
        basis.fill(0.0);
        basis[a] = 1.0;
        image.fill(0.0);
        op(row_idx, basis.view(), &mut image);
        for c in 0..di {
            dense[[c, a]] = image[c];
        }
    }
    dense
}
/// Accumulates `out += H_tβ[row]ᵀ * v` using only the forward operator.
///
/// For each of the `k` columns the operator is applied to a unit vector and
/// the resulting column (first `d` entries) is dotted with `v`. This costs `k`
/// operator applications per call.
fn htbeta_probe_transpose(
    row: usize,
    op: &RowHtbetaMatvec,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
    d: usize,
    k: usize,
) {
    let mut basis = Array1::<f64>::zeros(k);
    let mut column = Array1::<f64>::zeros(d);
    for a in 0..k {
        basis.fill(0.0);
        basis[a] = 1.0;
        column.fill(0.0);
        op(row, basis.view(), &mut column);
        let projection = (0..d).fold(0.0_f64, |acc, c| acc + column[c] * v[c]);
        out[a] += projection;
    }
}
/// How a factor cache can reproduce the per-row `H_tβ` blocks.
///
/// Every variant records `estimated_bytes`, the size estimate computed when
/// the cache was built.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
/// Dense copies of every row's block.
Dense {
blocks: Arc<[Array2<f64>]>,
estimated_bytes: usize,
},
/// A matrix-free per-row operator.
Matvec {
op: RowHtbetaMatvec,
estimated_bytes: usize,
},
/// Nothing retained; applications report failure unless a fallback operator is given.
Disabled {
estimated_bytes: usize,
},
}
/// Debug form showing the variant, the block count for `Dense`, and the byte
/// estimate; the operator of `Matvec` is not printed.
impl std::fmt::Debug for ArrowHtbetaCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, block_count, bytes) = match self {
            Self::Dense {
                blocks,
                estimated_bytes,
            } => ("Dense", Some(blocks.len()), estimated_bytes),
            Self::Matvec {
                estimated_bytes, ..
            } => ("Matvec", None, estimated_bytes),
            Self::Disabled { estimated_bytes } => ("Disabled", None, estimated_bytes),
        };
        let mut printer = f.debug_struct(label);
        if let Some(count) = block_count {
            printer.field("blocks", &count);
        }
        printer.field("estimated_bytes", bytes);
        printer.finish()
    }
}
impl ArrowHtbetaCache {
    /// True for every variant except `Disabled`.
    fn is_available(&self) -> bool {
        match self {
            Self::Disabled { .. } => false,
            Self::Dense { .. } | Self::Matvec { .. } => true,
        }
    }

    /// Writes `H_tβ[row] * delta_beta` into `out`.
    ///
    /// Returns `false` without touching `out` when the cache is disabled, the
    /// row is missing from a dense cache, or a dense block's shape does not
    /// match `delta_beta` / `out`. The `Matvec` variant always reports `true`.
    fn apply_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        match self {
            Self::Disabled { .. } => false,
            Self::Matvec { op, .. } => {
                op(row, delta_beta, out);
                true
            }
            Self::Dense { blocks, .. } => {
                let block = match blocks.get(row) {
                    Some(block) => block,
                    None => return false,
                };
                let (n_rows, n_cols) = block.dim();
                if n_cols != delta_beta.len() || n_rows != out.len() {
                    return false;
                }
                for c in 0..n_rows {
                    out[c] =
                        (0..n_cols).fold(0.0_f64, |acc, a| acc + block[[c, a]] * delta_beta[a]);
                }
                true
            }
        }
    }

    /// Accumulates `out += H_tβ[row]ᵀ * v`.
    ///
    /// `Dense` uses the stored block (skipping exactly-zero entries of `v`),
    /// `Matvec` probes its operator, and `Disabled` probes `fallback_op` when
    /// one is supplied. Returns `false` when nothing could be applied.
    fn apply_row_transpose_accumulate(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        d: usize,
        k: usize,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let block = match blocks.get(row) {
                    Some(block) => block,
                    None => return false,
                };
                let (n_rows, n_cols) = block.dim();
                if n_rows != v.len() || n_cols != out.len() {
                    return false;
                }
                for c in 0..n_rows {
                    let weight = v[c];
                    if weight != 0.0 {
                        for a in 0..n_cols {
                            out[a] += block[[c, a]] * weight;
                        }
                    }
                }
                true
            }
            Self::Matvec { op, .. } => {
                htbeta_probe_transpose(row, op, v, out, d, k);
                true
            }
            Self::Disabled { .. } => match fallback_op {
                Some(op) => {
                    htbeta_probe_transpose(row, op, v, out, d, k);
                    true
                }
                None => false,
            },
        }
    }
}
/// Factorisations and metadata retained from one Arrow-Schur Newton solve so
/// later predictions can reuse them.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
/// Per-row factors produced by the solve at `ridge_t`.
pub htt_factors: Arc<[Array2<f64>]>,
/// Per-row factors at zero ridge (shared with `htt_factors` when `ridge_t == 0.0`).
pub htt_factors_undamped: ArrowUndampedFactors,
/// Factor of the reduced β system, when the solve produced one.
pub schur_factor: Option<Array2<f64>>,
/// Solver mode the cache was built with.
pub solver_mode: ArrowSolverMode,
/// Ridge applied to the per-row blocks during the solve.
pub ridge_t: f64,
/// Ridge applied to the β block during the solve.
pub ridge_beta: f64,
/// Access to the per-row `H_tβ` blocks (dense, matrix-free, or disabled).
pub htbeta: ArrowHtbetaCache,
/// Latent dimension copied from the system; used as the fallback row dimension.
pub d: usize,
/// Per-row dimensions, shared with the system.
pub row_dims: Arc<[usize]>,
/// Per-row offsets into flat `delta_t` vectors, shared with the system.
pub row_offsets: Arc<[usize]>,
/// Dimension of the shared β block.
pub k: usize,
/// Manifold-mode fingerprint of the system at solve time.
pub manifold_mode_fingerprint: u64,
/// Row-Hessian fingerprint of the system at solve time.
pub row_hessian_fingerprint: u64,
/// Diagnostics reported by the solve.
pub pcg_diagnostics: PcgDiagnostics,
}
impl ArrowFactorCache {
/// Number of rows, taken as the number of damped per-row factors.
pub fn n_rows(&self) -> usize {
self.htt_factors.len()
}
/// Whether the cached `H_tβ` access is usable (i.e. not `Disabled`).
pub fn htbeta_available(&self) -> bool {
self.htbeta.is_available()
}
/// Undamped factor of `row`.
///
/// # Panics
/// Panics if `row` is out of range for the backing factor list.
pub fn undamped_factor(&self, row: usize) -> &Array2<f64> {
    let source: &[Array2<f64>] = match &self.htt_factors_undamped {
        ArrowUndampedFactors::SameAsDamped => &self.htt_factors[..],
        ArrowUndampedFactors::Owned(factors) => &factors[..],
    };
    &source[row]
}
pub fn undamped_factor_count(&self) -> usize {
match &self.htt_factors_undamped {
ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
ArrowUndampedFactors::Owned(factors) => factors.len(),
}
}
/// Iterates over the undamped factors in row order.
pub fn undamped_factors_iter(&self) -> impl Iterator<Item = &Array2<f64>> {
(0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
}
/// Length of a flat `delta_t` vector: the offset stored at index `n_rows()`.
///
/// Panics if `row_offsets` has no entry at that index.
pub fn delta_t_len(&self) -> usize {
self.row_offsets[self.n_rows()]
}
/// Writes `H_tβ[row] * delta_beta` into `out` using the cached access.
///
/// Returns `false` when `out` does not have the row's dimension
/// (`row_dims[row]`, or `d` for rows beyond `row_dims`), when `delta_beta`
/// does not have length `k`, or when the underlying cache cannot apply the row.
pub fn apply_htbeta_row(
    &self,
    row: usize,
    delta_beta: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) -> bool {
    let di = self.row_dims.get(row).copied().unwrap_or(self.d);
    let shapes_ok = out.len() == di && delta_beta.len() == self.k;
    shapes_ok && self.htbeta.apply_row(row, delta_beta, out)
}
/// Accumulates `out += H_tβ[row]ᵀ * v` using the cached access.
///
/// Returns `false` when `v` does not have the row's dimension
/// (`row_dims[row]`, or `d` for rows beyond `row_dims`), when `out` does not
/// have length `k`, or when the cache (and `fallback_op`) cannot apply the row.
pub fn apply_htbeta_row_transpose(
    &self,
    row: usize,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
    fallback_op: Option<&RowHtbetaMatvec>,
) -> bool {
    let di = self.row_dims.get(row).copied().unwrap_or(self.d);
    let shapes_ok = v.len() == di && out.len() == self.k;
    shapes_ok
        && self
            .htbeta
            .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
}
/// Predicts the latent response to a β perturbation with the undamped factors.
///
/// For every row `i` the result holds the negated `chol_solve_vector` solution
/// of the undamped factor against `H_tβ[i] * delta_beta`, written at
/// `row_offsets[i]`. If the `H_tβ` cache is unavailable, or any row cannot be
/// applied, an all-zero vector is returned.
///
/// The per-row right-hand side is allocated at the row's own dimension. The
/// former shared length-`d` scratch buffer was zero-filled and copied every
/// iteration but never read back.
///
/// # Panics
/// Panics if `delta_beta` does not have length `k`.
pub fn predict_delta_t_from_delta_beta(&self, delta_beta: ArrayView1<'_, f64>) -> Array1<f64> {
    let n = self.undamped_factor_count();
    let total_len = self.delta_t_len();
    assert_eq!(delta_beta.len(), self.k);
    if !self.htbeta_available() {
        return Array1::<f64>::zeros(total_len);
    }
    let mut out = Array1::<f64>::zeros(total_len);
    for i in 0..n {
        let di = self.row_dims[i];
        let mut rhs_i = Array1::<f64>::zeros(di);
        if !self.apply_htbeta_row(i, delta_beta.view(), &mut rhs_i) {
            return Array1::<f64>::zeros(total_len);
        }
        let v = chol_solve_vector(self.undamped_factor(i), &rhs_i);
        let row_base = self.row_offsets[i];
        for c in 0..di {
            out[row_base + c] = -v[c];
        }
    }
    out
}
/// Predicts the latent response to simultaneous β and row-gradient
/// perturbations with the undamped factors.
///
/// Per row the right-hand side is `H_tβ[i] * delta_beta` (when given) plus the
/// row's slice of `delta_gt` (when given); the negated `chol_solve_vector`
/// solution is written at `row_offsets[i]`. If `delta_beta` is given but a row
/// cannot be applied through the `H_tβ` cache, an all-zero vector is returned.
///
/// Buffers are allocated at each row's own dimension, replacing the shared
/// length-`d` scratch arrays that were sliced and copied every iteration. The
/// arithmetic and its order are unchanged.
///
/// # Panics
/// Panics if `delta_beta` does not have length `k` or `delta_gt` does not have
/// length `delta_t_len()`.
pub fn predict_delta_t_combined(
    &self,
    delta_beta: Option<ArrayView1<'_, f64>>,
    delta_gt: Option<ArrayView1<'_, f64>>,
) -> Array1<f64> {
    let n = self.undamped_factor_count();
    let total_len = self.delta_t_len();
    if let Some(db) = delta_beta.as_ref() {
        assert_eq!(db.len(), self.k);
    }
    if let Some(dg) = delta_gt.as_ref() {
        assert_eq!(dg.len(), total_len);
    }
    let mut out = Array1::<f64>::zeros(total_len);
    for i in 0..n {
        let di = self.row_dims[i];
        let row_base = self.row_offsets[i];
        let mut rhs = Array1::<f64>::zeros(di);
        if let Some(db) = delta_beta.as_ref() {
            let mut htbeta_delta = Array1::<f64>::zeros(di);
            if !self.apply_htbeta_row(i, db.view(), &mut htbeta_delta) {
                return Array1::<f64>::zeros(total_len);
            }
            for c in 0..di {
                rhs[c] += htbeta_delta[c];
            }
        }
        if let Some(dg) = delta_gt.as_ref() {
            for c in 0..di {
                rhs[c] += dg[row_base + c];
            }
        }
        let v = chol_solve_vector(self.undamped_factor(i), &rhs);
        for c in 0..di {
            out[row_base + c] = -v[c];
        }
    }
    out
}
/// Log-determinant pieces read off the stored factors.
///
/// Returns `(2 * Σ ln(diag))` over all damped per-row factors, and the same
/// quantity for the Schur factor when one is stored.
// NOTE(review): interpreting these as log-determinants assumes the stored
// matrices are Cholesky-style triangular factors — confirm with the producers.
pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
    // One running accumulator across all factors keeps the summation order.
    let diag_log_sum = self
        .htt_factors
        .iter()
        .flat_map(|factor| (0..factor.nrows()).map(move |i| factor[[i, i]].ln()))
        .fold(0.0_f64, |acc, term| acc + term);
    let log_det_tt = diag_log_sum * 2.0;
    let log_det_schur = self.schur_factor.as_ref().map(|factor| {
        let partial = (0..factor.nrows()).fold(0.0_f64, |acc, i| acc + factor[[i, i]].ln());
        2.0 * partial
    });
    (log_det_tt, log_det_schur)
}
/// Predicts the latent response to a row-gradient perturbation with the
/// undamped factors.
///
/// For every row the negated `chol_solve_vector` solution of the undamped
/// factor against the row's slice of `delta_gt` is written at
/// `row_offsets[i]`.
///
/// # Panics
/// Panics if `delta_gt` does not have length `delta_t_len()`, or if the number
/// of undamped factors differs from `n_rows()`. The consistency check
/// previously compared the undamped count with itself and could never fire.
pub fn predict_delta_t_from_delta_gt(&self, delta_gt: ArrayView1<'_, f64>) -> Array1<f64> {
    let n = self.n_rows();
    let total_len = self.delta_t_len();
    assert_eq!(delta_gt.len(), total_len);
    assert_eq!(
        self.undamped_factor_count(),
        n,
        "undamped factor cache and N must agree"
    );
    let mut out = Array1::<f64>::zeros(total_len);
    for i in 0..n {
        let di = self.row_dims[i];
        let row_base = self.row_offsets[i];
        let rhs = delta_gt
            .slice(ndarray::s![row_base..row_base + di])
            .to_owned();
        let v = chol_solve_vector(self.undamped_factor(i), &rhs);
        for c in 0..di {
            out[row_base + c] = -v[c];
        }
    }
    out
}
}
pub fn solve_arrow_newton_step_with_options(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, ArrowFactorCache), ArrowSchurError> {
if options.streaming_chunk_size.is_some() {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming Arrow-Schur solve does not materialize the factor cache required by this entry point".to_string(),
});
}
let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
let backend = CpuBatchedBlockSolver;
let htbeta_estimated_bytes =
estimated_htbeta_bytes(sys.rows.len(), sys.d, sys.k).unwrap_or(usize::MAX);
let htbeta = if let Some(op) = sys.htbeta_matvec.as_ref() {
ArrowHtbetaCache::Matvec {
op: Arc::clone(op),
estimated_bytes: htbeta_estimated_bytes,
}
} else if htbeta_estimated_bytes <= ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES {
ArrowHtbetaCache::Dense {
blocks: sys
.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>()
.into(),
estimated_bytes: htbeta_estimated_bytes,
}
} else {
ArrowHtbetaCache::Disabled {
estimated_bytes: htbeta_estimated_bytes,
}
};
let htt_factors = Arc::<[Array2<f64>]>::from(step.htt_factors);
let htt_factors_undamped = if ridge_t == 0.0 {
ArrowUndampedFactors::SameAsDamped
} else {
ArrowUndampedFactors::Owned(backend.factor_blocks(&sys.rows, 0.0, sys.d)?.into())
};
let cache = ArrowFactorCache {
htt_factors,
htt_factors_undamped,
schur_factor: step.schur_factor,
solver_mode: options.mode,
ridge_t,
ridge_beta,
htbeta,
d: sys.d,
row_dims: Arc::clone(&sys.row_dims),
row_offsets: Arc::clone(&sys.row_offsets),
k: sys.k,
manifold_mode_fingerprint: sys.manifold_mode_fingerprint,
row_hessian_fingerprint: sys.current_row_hessian_fingerprint(),
pcg_diagnostics: step.pcg_diagnostics,
};
Ok((step.delta_t, step.delta_beta, cache))
}
/// Byte size of `n` dense `d × k` blocks of `f64`, or `None` if the product
/// overflows `usize`.
fn estimated_htbeta_bytes(n: usize, d: usize, k: usize) -> Option<usize> {
    // Multiply left to right, bailing out on the first overflow.
    [d, k, std::mem::size_of::<f64>()]
        .iter()
        .try_fold(n, |bytes, &factor| bytes.checked_mul(factor))
}
/// Solves one Newton step and returns `(delta_t, delta_beta, diagnostics)`.
///
/// When `options.streaming_chunk_size` is set, the system is wrapped in a
/// `StreamingArrowSchur` and default diagnostics are returned; otherwise the
/// in-memory solver is used and its diagnostics are forwarded.
pub fn solve_arrow_newton_step_core(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    match options.streaming_chunk_size {
        Some(chunk_size) => {
            let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
            let (delta_t, delta_beta, _schur_factor) =
                streaming.solve(ridge_t, ridge_beta, options)?;
            Ok((delta_t, delta_beta, PcgDiagnostics::default()))
        }
        None => {
            let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
            Ok((step.delta_t, step.delta_beta, step.pcg_diagnostics))
        }
    }
}
/// Solves one Newton step, adding a growing proximal ridge to both blocks
/// whenever a factorisation fails.
///
/// The first attempt uses the caller's ridges unchanged. After a recoverable
/// failure (`PerRowFactorFailed`, `PerRowFactorIllConditioned`,
/// `SchurFactorFailed`) the extra ridge starts at
/// `DEFAULT_PROXIMAL_INITIAL_RIDGE` and is multiplied by
/// `DEFAULT_PROXIMAL_RIDGE_GROWTH` per retry, for at most
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS` retries. On success the number of retries
/// is recorded in `ridge_escalations`.
///
/// # Errors
/// Returns the most recent solver error when it is not recoverable or the
/// retry budget is exhausted.
pub fn solve_with_lm_escalation_inner(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let mut proximal_ridge = 0.0_f64;
    // Doubles as the attempt index and the count of escalations performed.
    let mut escalations: usize = 0;
    loop {
        let attempt = solve_arrow_newton_step_artifacts(
            sys,
            ridge_t + proximal_ridge,
            ridge_beta + proximal_ridge,
            options,
        );
        match attempt {
            Ok(mut step) => {
                step.pcg_diagnostics.ridge_escalations = escalations;
                return Ok((step.delta_t, step.delta_beta, step.pcg_diagnostics));
            }
            Err(err) => {
                let recoverable = matches!(
                    err,
                    ArrowSchurError::PerRowFactorFailed { .. }
                        | ArrowSchurError::PerRowFactorIllConditioned { .. }
                        | ArrowSchurError::SchurFactorFailed { .. }
                );
                if !recoverable || escalations == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
                    return Err(err);
                }
            }
        }
        proximal_ridge = if proximal_ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
        escalations += 1;
    }
}
/// Searches for a Newton step that satisfies an Armijo decrease test,
/// increasing a proximal ridge after every rejected candidate.
///
/// Each attempt solves the system with `base_ridge_* + proximal_ridge`,
/// requires the step to be a finite descent direction (`g·p < 0`), evaluates
/// `trial_objective` on it, and accepts when the trial value is finite and at
/// most `current_objective_value + armijo_c1 * g·p`. `trial_objective` is only
/// called for candidates that pass the descent check.
///
/// If the gradient norm is already within `gradient_tolerance`, a zero step is
/// returned with `attempts: 0` and no solve is performed.
///
/// # Errors
/// `AdaptiveCorrectionFailed` when the current objective is not finite, when
/// `ridge_growth` or `armijo_c1` are out of range, or when no candidate is
/// accepted within `max_attempts` (the message carries the last rejection
/// reason, including solver errors).
pub fn solve_arrow_newton_step_with_proximal_correction<F>(
sys: &ArrowSchurSystem,
base_ridge_t: f64,
base_ridge_beta: f64,
current_objective_value: f64,
options: &ArrowSolveOptions,
correction: &ArrowProximalCorrectionOptions,
mut trial_objective: F,
) -> Result<ArrowAcceptedProximalStep, ArrowSchurError>
where
F: for<'a, 'b> FnMut(ArrayView1<'a, f64>, ArrayView1<'b, f64>) -> f64,
{
// --- Input validation -------------------------------------------------
if !current_objective_value.is_finite() {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: "current objective is not finite".to_string(),
});
}
if !(correction.ridge_growth.is_finite() && correction.ridge_growth > 1.0) {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"ridge_growth must be finite and > 1; got {}",
correction.ridge_growth
),
});
}
if !(correction.armijo_c1.is_finite()
&& correction.armijo_c1 > 0.0
&& correction.armijo_c1 < 1.0)
{
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!("armijo_c1 must be in (0, 1); got {}", correction.armijo_c1),
});
}
// --- Already stationary: return a zero step without solving -----------
let grad_norm = arrow_gradient_norm(sys);
if grad_norm <= correction.gradient_tolerance.max(0.0) {
return Ok(ArrowAcceptedProximalStep {
delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
delta_beta: Array1::<f64>::zeros(sys.k),
ridge_t: base_ridge_t,
ridge_beta: base_ridge_beta,
proximal_ridge: 0.0,
objective_value: current_objective_value,
trial_objective_value: current_objective_value,
gradient_dot_step: 0.0,
attempts: 0,
});
}
// A negative configured initial ridge is clamped to zero.
let mut proximal_ridge = correction.initial_ridge.max(0.0);
let mut last_reason = String::from("no attempts were made");
for attempt in 0..correction.max_attempts {
let ridge_t = base_ridge_t + proximal_ridge;
let ridge_beta = base_ridge_beta + proximal_ridge;
match solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options) {
Ok((delta_t, delta_beta, _diag)) => {
let g_dot_p = arrow_gradient_dot_step(sys, delta_t.view(), delta_beta.view());
// The negated form also rejects a NaN directional derivative.
if !(g_dot_p.is_finite() && g_dot_p < 0.0) {
last_reason =
format!("candidate was not a finite descent direction: g·p={g_dot_p}");
} else {
let trial_value = trial_objective(delta_t.view(), delta_beta.view());
// Armijo sufficient-decrease bound for a unit step.
let armijo_bound = current_objective_value + correction.armijo_c1 * g_dot_p;
if trial_value.is_finite() && trial_value <= armijo_bound {
return Ok(ArrowAcceptedProximalStep {
delta_t,
delta_beta,
ridge_t,
ridge_beta,
proximal_ridge,
objective_value: current_objective_value,
trial_objective_value: trial_value,
gradient_dot_step: g_dot_p,
attempts: attempt + 1,
});
}
last_reason = format!(
"Armijo rejected trial objective {trial_value}; bound {armijo_bound}"
);
}
}
Err(err) => {
// Solver failures are treated like rejections: remember why and retry.
last_reason = err.to_string();
}
}
// Every rejection, whatever its cause, escalates the ridge.
proximal_ridge = next_proximal_ridge(proximal_ridge, correction.ridge_growth);
}
Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"failed after {} attempts; last rejection: {last_reason}",
correction.max_attempts
),
})
}
/// Predicted reduction of the ridge-damped quadratic model for a given step.
///
/// Returns `-(lin + 0.5 * quad)`, where `lin` is the gradient dotted with the
/// step (`gb·δβ` plus every row's `gt·δt`) and `quad` collects the β ridge and
/// penalty terms, each row's ridge and `htt` terms, and twice the `H_tβ`
/// cross term.
///
/// As written this function never returns `Err`.
///
/// # Panics
/// Panics if `delta_t` / `delta_beta` have the wrong length or are not
/// contiguous where a slice is required.
pub fn arrow_damped_quadratic_model_reduction(
sys: &ArrowSchurSystem,
delta_t: ArrayView1<'_, f64>,
delta_beta: ArrayView1<'_, f64>,
ridge_t: f64,
ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
let total_len = sys.row_offsets[sys.rows.len()];
assert_eq!(delta_t.len(), total_len);
assert_eq!(delta_beta.len(), sys.k);
let mut lin = sys.gb.dot(&delta_beta);
let mut quad = ridge_beta * delta_beta.dot(&delta_beta);
// β-block penalty term: the system's penalty matvec is applied to δβ,
// starting from a zeroed output.
let mut hbb_delta = Array1::<f64>::zeros(sys.k);
{
let x_slice = delta_beta
.as_slice()
.expect("delta_beta must be contiguous");
let y_slice = hbb_delta
.as_slice_mut()
.expect("hbb_delta must be contiguous");
sys.penalty_matvec_add(x_slice, y_slice);
}
quad += delta_beta.dot(&hbb_delta);
let mut htbeta_x = Array1::<f64>::zeros(sys.d);
for (i, row) in sys.rows.iter().enumerate() {
let di = sys.row_dims[i];
let row_base = sys.row_offsets[i];
// Per-row output buffer of length `di` for H_tβ[i] * δβ.
let mut htbeta_x_i = htbeta_x.slice_mut(ndarray::s![..di]).to_owned();
htbeta_x_i.fill(0.0);
sys_htbeta_apply_row(sys, i, row, delta_beta, &mut htbeta_x_i);
for c in 0..di {
let dt_c = delta_t[row_base + c];
lin += row.gt[c] * dt_c;
quad += ridge_t * dt_c * dt_c;
for r in 0..di {
quad += dt_c * row.htt[[c, r]] * delta_t[row_base + r];
}
// The cross block appears twice in the symmetric quadratic form.
quad += 2.0 * dt_c * htbeta_x_i[c];
}
}
Ok(-(lin + 0.5 * quad))
}
/// Predicted reduction of the quadratic model with the ridge terms removed.
///
/// Computed as the damped reduction plus `0.5 * ridge_beta * ‖δβ‖²` and
/// `0.5 * ridge_t * ‖δt‖²`, which cancels the ridge contributions that the
/// damped model subtracts.
///
/// # Errors
/// Propagates any error from `arrow_damped_quadratic_model_reduction`.
pub fn arrow_bare_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let damped =
        arrow_damped_quadratic_model_reduction(sys, delta_t, delta_beta, ridge_t, ridge_beta)?;
    let beta_correction = 0.5 * ridge_beta * delta_beta.dot(&delta_beta);
    let delta_t_sq_norm = delta_t.iter().fold(0.0_f64, |acc, v| acc + v * v);
    let t_correction = 0.5 * ridge_t * delta_t_sq_norm;
    Ok(damped + beta_correction + t_correction)
}
/// Next proximal ridge in the escalation schedule.
///
/// A positive ridge is multiplied by `growth`; anything else (zero, negative,
/// NaN) restarts at `DEFAULT_PROXIMAL_INITIAL_RIDGE`.
fn next_proximal_ridge(current: f64, growth: f64) -> f64 {
    match current > 0.0 {
        true => current * growth,
        // Also reached for NaN, since every comparison with NaN is false.
        false => DEFAULT_PROXIMAL_INITIAL_RIDGE,
    }
}
fn arrow_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
let mut sum = 0.0;
for row in sys.rows.iter() {
for &v in row.gt.iter() {
sum += v * v;
}
}
for &v in sys.gb.iter() {
sum += v * v;
}
sum.sqrt()
}
/// Directional derivative `g·p` of the step `(delta_t, delta_beta)`.
///
/// Row gradients are paired with `delta_t` through `row_offsets` / `row_dims`;
/// `gb` is paired with `delta_beta`.
///
/// # Panics
/// Panics if `delta_t` or `delta_beta` has the wrong length.
fn arrow_gradient_dot_step(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_t.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);
    let mut total = 0.0;
    for (i, row) in sys.rows.iter().enumerate() {
        let base = sys.row_offsets[i];
        total = (0..sys.row_dims[i]).fold(total, |acc, c| acc + row.gt[c] * delta_t[base + c]);
    }
    (0..sys.k).fold(total, |acc, a| acc + sys.gb[a] * delta_beta[a])
}
/// Everything one Newton solve produces: the step plus reusable factors.
struct ArrowNewtonStepArtifacts {
// Flat latent step, laid out by the system's `row_offsets`.
delta_t: Array1<f64>,
// Step for the shared β block.
delta_beta: Array1<f64>,
// Per-row factors at the solve's `ridge_t`; empty when the streaming path was used.
htt_factors: Vec<Array2<f64>>,
// Factor of the reduced β system; `None` for the inexact PCG mode.
schur_factor: Option<Array2<f64>>,
// Solver diagnostics (defaults on the streaming path).
pcg_diagnostics: PcgDiagnostics,
}
/// Core Newton solve returning the step together with its factorisations.
///
/// With `streaming_chunk_size` set, the work is delegated to
/// `StreamingArrowSchur` (no per-row factors are retained). Otherwise:
/// 1. every row block is factorised with `ridge_t`;
/// 2. the reduced β right-hand side is formed;
/// 3. the reduced system is solved densely (`Direct`, `SqrtBA`) or by
///    Steihaug PCG (`InexactPCG`, which yields no Schur factor);
/// 4. each row's `delta_t` is recovered as the negated block solve of
///    `gt + H_tβ * delta_beta`.
///
/// # Errors
/// Propagates failures from factorisation, Schur assembly and the reduced solve.
fn solve_arrow_newton_step_artifacts(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
if let Some(chunk_size) = options.streaming_chunk_size {
let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
let (delta_t, delta_beta, schur_factor) = streaming.solve(ridge_t, ridge_beta, options)?;
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors: Vec::new(),
schur_factor,
pcg_diagnostics: PcgDiagnostics::default(),
});
}
let n = sys.rows.len();
let backend = CpuBatchedBlockSolver;
let htt_factors = backend.factor_blocks(&sys.rows, ridge_t, sys.d)?;
let rhs_beta = reduced_rhs_beta(sys, &htt_factors, &backend);
// No trust-region metric weights are supplied on this path.
let trust_metric_weights = None;
let (delta_beta, schur_factor, pcg_diagnostics) = match options.mode {
ArrowSolverMode::Direct => {
let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, &backend)?;
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::SqrtBA => {
let schur = build_dense_schur_sqrt_ba(sys, &htt_factors, ridge_beta, &backend)?;
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::InexactPCG => {
let (delta, diag) = steihaug_pcg_auto(
sys,
&htt_factors,
ridge_beta,
&rhs_beta,
&options.pcg,
&options.trust_region,
&backend,
options.gpu_matvec.as_ref(),
trust_metric_weights,
)?;
(delta, None, diag)
}
};
// Back-substitution: recover each row's latent step from delta_beta.
let total_dt_len = sys.row_offsets[n];
let mut delta_t = Array1::<f64>::zeros(total_dt_len);
let mut rhs = Array1::<f64>::zeros(sys.d);
let mut htbeta_delta = Array1::<f64>::zeros(sys.d);
for i in 0..n {
let di = sys.row_dims[i];
let row_base = sys.row_offsets[i];
assert_eq!(sys.rows[i].gt.len(), di);
for c in 0..di {
htbeta_delta[c] = 0.0;
}
// Owned length-`di` buffer receiving H_tβ[i] * delta_beta.
let mut htbeta_slice = htbeta_delta.slice_mut(ndarray::s![..di]).to_owned();
sys_htbeta_apply_row(sys, i, &sys.rows[i], delta_beta.view(), &mut htbeta_slice);
let mut rhs_i = rhs.slice_mut(ndarray::s![..di]);
for c in 0..di {
rhs_i[c] = sys.rows[i].gt[c] + htbeta_slice[c];
}
drop(rhs_i);
// The block solve receives an owned right-hand side of exactly `di` entries.
let rhs_slice = rhs.slice(ndarray::s![..di]).to_owned();
let dt_i = backend.solve_block_vector(&htt_factors[i], &rhs_slice);
for c in 0..di {
delta_t[row_base + c] = -dt_i[c];
}
}
Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor,
pcg_diagnostics,
})
}
/// Assembles the right-hand side of the reduced (Schur-complement) β system.
///
/// For every row the per-row factor `htt_factors[i]` is solved against the
/// row gradient `gt`, and the result is accumulated through the transposed
/// row coupling into a length-`sys.k` vector. The shared gradient `sys.gb`
/// is subtracted at the end.
fn reduced_rhs_beta<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    backend: &B,
) -> Array1<f64> {
    let mut reduced = Array1::<f64>::zeros(sys.k);
    for (row_idx, row) in sys.rows.iter().enumerate() {
        let solved = backend.solve_block_vector(&htt_factors[row_idx], &row.gt);
        sys_htbeta_accumulate_transpose(sys, row_idx, row, solved.view(), &mut reduced);
    }
    for (j, entry) in reduced.iter_mut().enumerate() {
        *entry -= sys.gb[j];
    }
    reduced
}
/// Builds the dense reduced matrix for the direct solve path.
///
/// Starts from the dense shared penalty operator, adds `ridge_beta` to its
/// diagonal, then for every row subtracts the product formed by
/// `block_gemm_subtract` from the materialized row coupling and its
/// `solve_block_matrix` image. The result is symmetrized before returning.
///
/// # Errors
/// Returns `SchurFactorFailed` when the penalty operator dimension is not
/// `sys.k`.
fn build_dense_schur_direct<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
    let penalty = sys.effective_penalty_op();
    if penalty.dim() != sys.k {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: String::from("Direct BA requires a K×K shared H_ββ penalty operator"),
        });
    }

    let mut reduced = penalty.to_dense();
    (0..sys.k).for_each(|j| reduced[[j, j]] += ridge_beta);

    for (row_idx, row) in sys.rows.iter().enumerate() {
        let coupling = sys_htbeta_materialize_row(sys, row_idx, row);
        let solved = backend.solve_block_matrix(&htt_factors[row_idx], &coupling);
        backend.block_gemm_subtract(&mut reduced, &coupling, &solved);
    }

    symmetrize_upper_from_lower(&mut reduced);
    Ok(reduced)
}
/// Builds the dense reduced matrix for the square-root solve path.
///
/// Same layout as the direct builder, except that each row coupling is
/// first passed through `sqrt_solve_block_matrix` and the resulting
/// "whitened" block is handed to `block_gemm_subtract` as both operands.
///
/// # Errors
/// Returns `SchurFactorFailed` when the penalty operator dimension is not
/// `sys.k`.
fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
    let penalty = sys.effective_penalty_op();
    if penalty.dim() != sys.k {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: String::from(
                "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator",
            ),
        });
    }

    let mut reduced = penalty.to_dense();
    (0..sys.k).for_each(|j| reduced[[j, j]] += ridge_beta);

    for (row_idx, row) in sys.rows.iter().enumerate() {
        let coupling = sys_htbeta_materialize_row(sys, row_idx, row);
        let whitened = backend.sqrt_solve_block_matrix(&htt_factors[row_idx], &coupling);
        backend.block_gemm_subtract(&mut reduced, &whitened, &whitened);
    }

    symmetrize_upper_from_lower(&mut reduced);
    Ok(reduced)
}
/// Solves the dense reduced system, honouring the trust region.
///
/// The matrix is Cholesky-factored and solved directly. If that step lies
/// inside the trust region it is returned with default diagnostics;
/// otherwise an unpreconditioned Steihaug-CG solve on the same dense matrix
/// produces the step. The Cholesky factor is returned in both cases.
///
/// # Errors
/// `SchurFactorFailed` when the Cholesky factorization fails; any error
/// raised by the Steihaug-CG fallback is propagated.
fn solve_dense_reduced_system(
    schur: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    options: &ArrowSolveOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
    let factor = match cholesky_lower(schur) {
        Ok(lower) => lower,
        Err(reason) => return Err(ArrowSchurError::SchurFactorFailed { reason }),
    };

    let newton_step = chol_solve_vector(&factor, rhs_beta);
    let radius = options.trust_region.radius;
    if step_inside_trust_region(newton_step.view(), radius, metric_weights) {
        return Ok((newton_step, Some(factor), PcgDiagnostics::default()));
    }

    // The full step leaves the trust region: fall back to truncated CG,
    // driven by the trust-region iteration cap and tolerance.
    let cg_options = ArrowPcgOptions {
        max_iterations: options.trust_region.max_iterations,
        relative_tolerance: options.trust_region.steihaug_relative_tolerance,
    };
    let (truncated_step, diagnostics) = steihaug_dense_system(
        schur,
        rhs_beta,
        &IdentityPreconditioner,
        &cg_options,
        &options.trust_region,
        metric_weights,
    )?;
    Ok((truncated_step, Some(factor), diagnostics))
}
/// Reports whether `step` is admissible for the given trust-region radius.
///
/// A non-finite radius (infinite or NaN) is treated as "no constraint";
/// otherwise the metric norm of the step must not exceed the radius.
fn step_inside_trust_region(
    step: ArrayView1<'_, f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> bool {
    if radius.is_finite() {
        metric_norm(step, metric_weights) <= radius
    } else {
        true
    }
}
/// Applies the reduced (Schur-complement) operator to `x`, writing the
/// product into `out`.
///
/// The product is assembled as: penalty operator applied to `x`, plus
/// `ridge_beta * x`, minus — for every row — the row coupling applied to
/// `x`, solved against that row's factor, and mapped back through the
/// transposed coupling.
///
/// `out` is overwritten. Every contribution below is an accumulation
/// (`penalty_matvec_add`, `+=`, `-=`), and the CG driver reuses a single
/// output buffer across iterations, so the buffer is cleared first. This
/// also matches `dense_matvec`, which overwrites its output.
///
/// # Panics
/// Panics if `x` or `out` is not contiguous.
fn schur_matvec<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    x: &Array1<f64>,
    out: &mut Array1<f64>,
    backend: &B,
) {
    let k = sys.k;
    // Start from a clean slate so stale values from a previous product
    // cannot leak into this one.
    out.fill(0.0);
    {
        let x_slice = x.as_slice().expect("x must be contiguous");
        let out_slice = out.as_slice_mut().expect("out must be contiguous");
        sys.penalty_matvec_add(x_slice, out_slice);
        for a in 0..k {
            out_slice[a] += ridge_beta * x_slice[a];
        }
    }
    let mut neg_contrib = Array1::<f64>::zeros(k);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        // Per-row scratch of exactly this row's dimension.
        let mut local_i = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
        let solved = backend.solve_block_vector(&htt_factors[i], &local_i);
        neg_contrib.fill(0.0);
        sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), &mut neg_contrib);
        for a in 0..k {
            out[a] -= neg_contrib[a];
        }
    }
}
/// Factor stored for one index range of the block-diagonal preconditioner.
#[derive(Clone)]
enum BlockFactor {
    /// Cholesky factorization of the block covering `range`.
    Chol {
        factor: FaerLlt<f64>,
        range: Range<usize>,
    },
    /// Element-wise multipliers applied over `range`.
    Scalar {
        inv: Array1<f64>,
        range: Range<usize>,
    },
}

/// Compact debug output: prints the covered range (and the multiplier
/// count) rather than the factor contents.
impl std::fmt::Debug for BlockFactor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BlockFactor::Scalar { inv, range } => {
                let inv_len = inv.len();
                write!(
                    f,
                    "BlockFactor::Scalar {{ inv.len: {inv_len}, range: {range:?} }}"
                )
            }
            BlockFactor::Chol { range, .. } => {
                write!(f, "BlockFactor::Chol {{ range: {range:?} }}")
            }
        }
    }
}
/// Block-diagonal (Jacobi-type) preconditioner for the reduced β system.
///
/// Holds one [`BlockFactor`] per preconditioned index range; `apply` treats
/// every range independently.
#[derive(Debug, Clone)]
pub struct JacobiPreconditioner {
    blocks: Vec<BlockFactor>,
}

/// Widest β block for which `from_arrow_schur` still builds dense per-block
/// factors; wider blocks switch the whole preconditioner to the scalar
/// diagonal variant.
const BLOCK_JACOBI_MAX_BLOCK: usize = 256;

impl JacobiPreconditioner {
    /// Builds the preconditioner, choosing between the block and scalar
    /// variants.
    ///
    /// The block variant is used when `sys.block_offsets` is non-empty and
    /// its widest range does not exceed [`BLOCK_JACOBI_MAX_BLOCK`].
    pub fn from_arrow_schur<B: BatchedBlockSolver>(
        sys: &ArrowSchurSystem,
        htt_factors: &[Array2<f64>],
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let widest_block = sys
            .block_offsets
            .iter()
            .map(|r| r.end.saturating_sub(r.start))
            .max()
            .unwrap_or(0);
        if !sys.block_offsets.is_empty() && widest_block <= BLOCK_JACOBI_MAX_BLOCK {
            Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
        } else {
            Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
        }
    }

    /// Scalar variant: one reciprocal per β index, taken from the diagonal
    /// of the reduced operator.
    ///
    /// # Errors
    /// `PcgFailed` when a diagonal entry is non-finite or not above `1e-18`.
    fn build_scalar_jacobi<B: BatchedBlockSolver>(
        sys: &ArrowSchurSystem,
        htt_factors: &[Array2<f64>],
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let k = sys.k;
        let mut diag = Array1::<f64>::zeros(k);
        sys.penalty_diagonal_add(diag.as_slice_mut().expect("diag must be contiguous"));
        for entry in diag.iter_mut() {
            *entry += ridge_beta;
        }

        let scratch = Array1::<f64>::zeros(sys.d);
        let mut unit = Array1::<f64>::zeros(k);
        for (i, row) in sys.rows.iter().enumerate() {
            let di = sys.row_dims[i];
            let mut col_i = scratch.slice(ndarray::s![..di]).to_owned();
            // The stored coupling is read directly only when no row operator
            // is installed and its shape is (di, k).
            let needs_operator = sys.htbeta_matvec.is_some() || row.htbeta.dim() != (di, k);
            for a in 0..k {
                if needs_operator {
                    // Extract column `a` by applying the row to a unit vector.
                    unit.fill(0.0);
                    unit[a] = 1.0;
                    col_i.fill(0.0);
                    sys_htbeta_apply_row(sys, i, row, unit.view(), &mut col_i);
                } else {
                    for c in 0..di {
                        col_i[c] = row.htbeta[[c, a]];
                    }
                }
                let solved = backend.solve_block_vector(&htt_factors[i], &col_i);
                let quad = (0..di).fold(0.0_f64, |acc, c| acc + col_i[c] * solved[c]);
                diag[a] -= quad;
            }
        }

        let mut blocks = Vec::with_capacity(k);
        for (a, &v) in diag.iter().enumerate() {
            if !(v.is_finite() && v > 1e-18) {
                return Err(ArrowSchurError::PcgFailed {
                    reason: format!(
                        "invalid Schur Jacobi diagonal at index {a}: {v}; \
                         operator regularization is required"
                    ),
                });
            }
            blocks.push(BlockFactor::Scalar {
                inv: Array1::from_elem(1, 1.0 / v),
                range: a..a + 1,
            });
        }
        Ok(Self { blocks })
    }

    /// Block variant: one dense block of the reduced operator per entry of
    /// `sys.block_offsets`, Cholesky-factored when possible.
    ///
    /// If the factorization of a block fails, that block falls back to the
    /// reciprocals of its diagonal.
    ///
    /// # Errors
    /// `PcgFailed` when the fallback meets a diagonal entry that is
    /// non-finite or not above `1e-18`.
    fn build_block_jacobi<B: BatchedBlockSolver>(
        sys: &ArrowSchurSystem,
        htt_factors: &[Array2<f64>],
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let block_offsets = &sys.block_offsets;
        let mut blocks = Vec::with_capacity(block_offsets.len());
        for (block_idx, range) in block_offsets.iter().enumerate() {
            let width = range.end - range.start;
            let mut schur_block = Array2::<f64>::zeros((width, width));
            sys.penalty_block_add(
                BetaBlockId(block_idx),
                block_offsets.as_ref(),
                &mut schur_block,
            );
            for local in 0..width {
                schur_block[[local, local]] += ridge_beta;
            }

            for (i, row) in sys.rows.iter().enumerate() {
                let di = sys.row_dims[i];
                let htbeta_full = sys_htbeta_materialize_row(sys, i, row);

                // Solve the row factor against each coupling column of the block.
                let mut solved_cols = Array2::<f64>::zeros((di, width));
                for (bj, gj) in range.clone().enumerate() {
                    let column = htbeta_full.column(gj).to_owned();
                    let solved = backend.solve_block_vector(&htt_factors[i], &column);
                    for c in 0..di {
                        solved_cols[[c, bj]] = solved[c];
                    }
                }

                // Subtract this row's contribution from the block.
                for (bi, gi) in range.clone().enumerate() {
                    for bj in 0..width {
                        let mut acc = 0.0;
                        for c in 0..di {
                            acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
                        }
                        schur_block[[bi, bj]] -= acc;
                    }
                }
            }

            let cholesky = {
                use faer::Side;
                let view = FaerArrayView::new(&schur_block);
                FaerLlt::new(view.as_ref(), Side::Lower).ok()
            };
            match cholesky {
                Some(factor) => blocks.push(BlockFactor::Chol {
                    factor,
                    range: range.clone(),
                }),
                None => {
                    let mut inv = Array1::<f64>::zeros(width);
                    for bi in 0..width {
                        let v = schur_block[[bi, bi]];
                        if !(v.is_finite() && v > 1e-18) {
                            return Err(ArrowSchurError::PcgFailed {
                                reason: format!(
                                    "block Jacobi scalar fallback: non-PD diagonal at \
                                     global index {}: {v}; regularization required",
                                    range.start + bi
                                ),
                            });
                        }
                        inv[bi] = 1.0 / v;
                    }
                    blocks.push(BlockFactor::Scalar {
                        inv,
                        range: range.clone(),
                    });
                }
            }
        }
        Ok(Self { blocks })
    }

    /// Applies the preconditioner to `r`, returning a vector of the same
    /// length. Indices not covered by any stored block stay zero.
    fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(r.len());
        for block in &self.blocks {
            match block {
                BlockFactor::Chol { factor, range } => {
                    use faer::linalg::solvers::Solve;
                    let width = range.end - range.start;
                    let mut gathered = Array1::<f64>::zeros(width);
                    for (local, gi) in range.clone().enumerate() {
                        gathered[local] = r[gi];
                    }
                    let stride = gathered.strides()[0];
                    let len = gathered.len();
                    // SAFETY: `gathered` owns `len` elements spaced `stride`
                    // apart and outlives `rhs_mat`; the view has a single
                    // column, so the column stride of 0 is never used to
                    // reach a second column.
                    let rhs_mat = unsafe {
                        faer::MatRef::from_raw_parts(gathered.as_ptr(), len, 1, stride, 0)
                    };
                    let solved = factor.solve(rhs_mat);
                    for (local, gi) in range.clone().enumerate() {
                        out[gi] = solved[(local, 0)];
                    }
                }
                BlockFactor::Scalar { inv, range } => {
                    for (local, gi) in range.clone().enumerate() {
                        out[gi] = inv[local] * r[gi];
                    }
                }
            }
        }
        out
    }
}
/// Names a preconditioner flavour for the reduced (Schur) system.
///
/// NOTE(review): nothing in this part of the file consumes the enum; the
/// per-variant notes below are inferred from the variant names and the
/// preconditioner types defined nearby — confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurPreconditionerKind {
/// Presumably the scalar (per-index) diagonal preconditioner.
Diagonal,
/// Presumably one dense block per β block.
BetaBlockJacobi,
/// Presumably one dense block per cluster of coupled β blocks.
ClusterJacobi,
/// Presumably overlapping clusters; `overlap` looks like the number of
/// expansion hops — TODO confirm.
AdditiveSchwarz { overlap: usize },
}
/// β dimension at or below which `steihaug_pcg_auto` keeps the first
/// (Jacobi) PCG result and never escalates to a stronger preconditioner.
const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;
/// Factor stored for one cluster of (not necessarily contiguous) β columns.
#[derive(Clone)]
enum ClusterFactor {
    /// Cholesky factorization of the cluster block; `cols` lists the global
    /// column index of each local row/column.
    Chol {
        cols: Vec<usize>,
        factor: FaerLlt<f64>,
    },
    /// Element-wise multipliers, one per entry of `cols`.
    Scalar {
        cols: Vec<usize>,
        inv: Vec<f64>,
    },
}

/// Compact debug output: prints only the sizes, not the factor contents.
impl std::fmt::Debug for ClusterFactor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusterFactor::Scalar { cols, inv } => {
                let (n_cols, n_inv) = (cols.len(), inv.len());
                write!(
                    f,
                    "ClusterFactor::Scalar {{ cols.len: {n_cols}, inv.len: {n_inv} }}"
                )
            }
            ClusterFactor::Chol { cols, .. } => {
                let n_cols = cols.len();
                write!(f, "ClusterFactor::Chol {{ cols.len: {n_cols} }}")
            }
        }
    }
}
const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;
#[derive(Debug, Clone)]
pub struct ClusterJacobiPreconditioner {
clusters: Vec<ClusterFactor>,
}
impl ClusterJacobiPreconditioner {
pub fn from_arrow_schur<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &[Array2<f64>],
ridge_beta: f64,
backend: &B,
) -> Result<Self, ArrowSchurError> {
if sys.block_offsets.is_empty() {
let cols: Vec<usize> = (0..sys.k).collect();
return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
}
let graph = BetaCouplingGraph::build(
&sys.block_offsets,
&sys.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>(),
);
let col_groups: Vec<Vec<usize>> = graph
.component_partition()
.iter()
.map(|comp_blocks| {
let mut cols: Vec<usize> = comp_blocks
.iter()
.flat_map(|&b| sys.block_offsets[b].clone())
.collect();
cols.sort_unstable();
cols
})
.collect();
Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
}
fn build_from_column_groups<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &[Array2<f64>],
ridge_beta: f64,
backend: &B,
col_groups: &[Vec<usize>],
) -> Result<Self, ArrowSchurError> {
let d = sys.d;
let mut clusters = Vec::with_capacity(col_groups.len());
for cols in col_groups {
let b = cols.len();
if b == 0 {
continue;
}
if b > CLUSTER_JACOBI_MAX_CLUSTER {
let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
clusters.push(ClusterFactor::Scalar {
cols: cols.clone(),
inv,
});
continue;
}
let mut s_block = Array2::<f64>::zeros((b, b));
sys.penalty_subblock_add(cols, &mut s_block);
for bi in 0..b {
s_block[[bi, bi]] += ridge_beta;
}
let mut col_vec = Array1::<f64>::zeros(d);
let mut solved_cols = Array2::<f64>::zeros((d, b));
for (row_idx, row) in sys.rows.iter().enumerate() {
for bj in 0..b {
let gj = cols[bj];
for c in 0..d {
col_vec[c] = row.htbeta[[c, gj]];
}
let solved = backend.solve_block_vector(&htt_factors[row_idx], &col_vec);
for c in 0..d {
solved_cols[[c, bj]] = solved[c];
}
}
for bi in 0..b {
let gi = cols[bi];
for bj in 0..b {
let mut acc = 0.0;
for c in 0..d {
acc += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
}
s_block[[bi, bj]] -= acc;
}
}
}
symmetrize_upper_from_lower(&mut s_block);
let factor_opt = {
use faer::Side;
let view = FaerArrayView::new(&s_block);
FaerLlt::new(view.as_ref(), Side::Lower).ok()
};
if let Some(llt) = factor_opt {
clusters.push(ClusterFactor::Chol {
cols: cols.clone(),
factor: llt,
});
} else {
let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
clusters.push(ClusterFactor::Scalar {
cols: cols.clone(),
inv,
});
}
}
Ok(Self { clusters })
}
fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(r.len());
for cluster in &self.clusters {
apply_cluster_non_overlapping(cluster, r, &mut out);
}
out
}
}
/// Overlapping cluster preconditioner: cluster solves are summed, each
/// column scaled by `weights[col]`.
#[derive(Debug, Clone)]
pub struct AdditiveSchwarzPreconditioner {
    clusters: Vec<ClusterFactor>,
    /// Per-column weight: `1 / (number of clusters containing the column)`,
    /// or `1.0` for a column in no cluster.
    weights: Vec<f64>,
}

impl AdditiveSchwarzPreconditioner {
    /// Builds the preconditioner from coupling-graph components, each
    /// expanded by `overlap` calls to `expand_one_hop`.
    ///
    /// With no `block_offsets`, all `sys.k` columns form a single cluster
    /// with unit weights.
    ///
    /// # Errors
    /// Propagates errors from
    /// `ClusterJacobiPreconditioner::build_from_column_groups`.
    pub fn from_arrow_schur<B: BatchedBlockSolver>(
        sys: &ArrowSchurSystem,
        htt_factors: &[Array2<f64>],
        ridge_beta: f64,
        backend: &B,
        overlap: usize,
    ) -> Result<Self, ArrowSchurError> {
        if sys.block_offsets.is_empty() {
            let cols: Vec<usize> = (0..sys.k).collect();
            let inner = ClusterJacobiPreconditioner::build_from_column_groups(
                sys,
                htt_factors,
                ridge_beta,
                backend,
                &[cols],
            )?;
            return Ok(Self {
                clusters: inner.clusters,
                weights: vec![1.0f64; sys.k],
            });
        }
        let graph = BetaCouplingGraph::build(
            &sys.block_offsets,
            &sys.rows
                .iter()
                .map(|r| r.htbeta.clone())
                .collect::<Vec<_>>(),
        );
        let col_groups: Vec<Vec<usize>> = graph
            .component_partition()
            .iter()
            .map(|seed| {
                // Grow the seed block set one hop at a time.
                let mut current = seed.clone();
                for _ in 0..overlap {
                    current = graph.expand_one_hop(&current);
                }
                let mut cols: Vec<usize> = current
                    .iter()
                    .flat_map(|&b| sys.block_offsets[b].clone())
                    .collect();
                cols.sort_unstable();
                cols.dedup();
                cols
            })
            .collect();

        // Count how many groups contain each column, so overlapping
        // contributions can be averaged.
        let mut counts = vec![0u32; sys.k];
        for cols in &col_groups {
            for &gi in cols {
                counts[gi] += 1;
            }
        }
        let weights: Vec<f64> = counts
            .iter()
            .map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
            .collect();
        let inner = ClusterJacobiPreconditioner::build_from_column_groups(
            sys,
            htt_factors,
            ridge_beta,
            backend,
            &col_groups,
        )?;
        Ok(Self {
            clusters: inner.clusters,
            weights,
        })
    }

    /// Sums the weighted cluster solves of `r` into a fresh vector.
    fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(r.len());
        for cluster in &self.clusters {
            apply_cluster_overlapping(cluster, r, &mut out, &self.weights);
        }
        out
    }
}
/// Applies one cluster factor to its columns of `r`, overwriting the
/// corresponding entries of `out`.
fn apply_cluster_non_overlapping(cluster: &ClusterFactor, r: &Array1<f64>, out: &mut Array1<f64>) {
    use faer::linalg::solvers::Solve;
    match cluster {
        ClusterFactor::Chol { cols, factor } => {
            // Gather the cluster's entries of `r` into a dense vector.
            let mut gathered = Array1::<f64>::zeros(cols.len());
            for (local, &gi) in cols.iter().enumerate() {
                gathered[local] = r[gi];
            }
            let stride = gathered.strides()[0];
            let len = gathered.len();
            // SAFETY: `gathered` owns `len` elements spaced `stride` apart
            // and outlives `rhs_mat`; the view has a single column, so the
            // column stride of 0 is never used to reach a second column.
            let rhs_mat =
                unsafe { faer::MatRef::from_raw_parts(gathered.as_ptr(), len, 1, stride, 0) };
            let solved = factor.solve(rhs_mat);
            for (local, &gi) in cols.iter().enumerate() {
                out[gi] = solved[(local, 0)];
            }
        }
        ClusterFactor::Scalar { cols, inv } => {
            for (local, &gi) in cols.iter().enumerate() {
                out[gi] = inv[local] * r[gi];
            }
        }
    }
}
/// Applies one cluster factor to its columns of `r` and adds the result,
/// scaled per column by `weights`, into `out`.
fn apply_cluster_overlapping(
    cluster: &ClusterFactor,
    r: &Array1<f64>,
    out: &mut Array1<f64>,
    weights: &[f64],
) {
    use faer::linalg::solvers::Solve;
    match cluster {
        ClusterFactor::Chol { cols, factor } => {
            // Gather the cluster's entries of `r` into a dense vector.
            let mut gathered = Array1::<f64>::zeros(cols.len());
            for (local, &gi) in cols.iter().enumerate() {
                gathered[local] = r[gi];
            }
            let stride = gathered.strides()[0];
            let len = gathered.len();
            // SAFETY: `gathered` owns `len` elements spaced `stride` apart
            // and outlives `rhs_mat`; the view has a single column, so the
            // column stride of 0 is never used to reach a second column.
            let rhs_mat =
                unsafe { faer::MatRef::from_raw_parts(gathered.as_ptr(), len, 1, stride, 0) };
            let solved = factor.solve(rhs_mat);
            for (local, &gi) in cols.iter().enumerate() {
                out[gi] += weights[gi] * solved[(local, 0)];
            }
        }
        ClusterFactor::Scalar { cols, inv } => {
            for (local, &gi) in cols.iter().enumerate() {
                out[gi] += weights[gi] * inv[local] * r[gi];
            }
        }
    }
}
fn build_schur_scalar_inv<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &[Array2<f64>],
ridge_beta: f64,
backend: &B,
cols: &[usize],
) -> Result<Vec<f64>, ArrowSchurError> {
let d = sys.d;
let mut result = Vec::with_capacity(cols.len());
let mut col_vec = Array1::<f64>::zeros(d);
let mut full_diag = Array1::<f64>::zeros(sys.k);
{
let fd_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
sys.penalty_diagonal_add(fd_slice);
}
for &gi in cols {
let mut s = full_diag[gi] + ridge_beta;
for (row_idx, row) in sys.rows.iter().enumerate() {
for c in 0..d {
col_vec[c] = row.htbeta[[c, gi]];
}
let solved = backend.solve_block_vector(&htt_factors[row_idx], &col_vec);
let mut acc = 0.0;
for c in 0..d {
acc += col_vec[c] * solved[c];
}
s -= acc;
}
if !s.is_finite() || s <= 1e-18 {
return Err(ArrowSchurError::PcgFailed {
reason: format!(
"cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
),
});
}
result.push(1.0 / s);
}
Ok(result)
}
/// Runs Steihaug-PCG on the reduced system, escalating the preconditioner
/// only when needed.
///
/// The Jacobi preconditioner is tried first. Its result is returned unless
/// the run stopped on the iteration cap *and* `sys.k` exceeds
/// `PRECOND_ESCALATE_K_THRESHOLD`. In that case the solve is repeated from
/// scratch with the cluster-Jacobi preconditioner and, if that also hits
/// the cap, once more with additive Schwarz (overlap 1), whose result is
/// returned whatever its stopping reason.
///
/// # Errors
/// Propagates preconditioner-construction and PCG errors.
fn steihaug_pcg_auto<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    rhs: &Array1<f64>,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    backend: &B,
    gpu_matvec: Option<&GpuSchurMatvec>,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    // Stage 1: (block) Jacobi.
    let jacobi = JacobiPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend)?;
    let (jacobi_step, jacobi_diag) = run_pcg_with_preconditioner(
        sys,
        htt_factors,
        ridge_beta,
        rhs,
        |r| jacobi.apply(r),
        pcg,
        trust,
        backend,
        gpu_matvec,
        metric_weights,
    )?;
    let jacobi_stalled = jacobi_diag.stopping_reason == PcgStopReason::MaxIter;
    if !(jacobi_stalled && sys.k > PRECOND_ESCALATE_K_THRESHOLD) {
        return Ok((jacobi_step, jacobi_diag));
    }

    // Stage 2: cluster Jacobi.
    let cluster =
        ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend)?;
    let (cluster_step, cluster_diag) = run_pcg_with_preconditioner(
        sys,
        htt_factors,
        ridge_beta,
        rhs,
        |r| cluster.apply(r),
        pcg,
        trust,
        backend,
        gpu_matvec,
        metric_weights,
    )?;
    if cluster_diag.stopping_reason != PcgStopReason::MaxIter {
        return Ok((cluster_step, cluster_diag));
    }

    // Stage 3: additive Schwarz with one hop of overlap.
    let schwarz =
        AdditiveSchwarzPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend, 1)?;
    run_pcg_with_preconditioner(
        sys,
        htt_factors,
        ridge_beta,
        rhs,
        |r| schwarz.apply(r),
        pcg,
        trust,
        backend,
        gpu_matvec,
        metric_weights,
    )
}
/// Runs one Steihaug-CG solve of the reduced system with the supplied
/// preconditioner.
///
/// The iteration cap is the smaller of the PCG and trust-region caps; the
/// relative tolerance is the larger of the two tolerances. The operator
/// product comes from `gpu_matvec` when one is provided, otherwise from
/// `schur_matvec`.
fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    rhs: &Array1<f64>,
    apply_prec: ApplyPrec,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    backend: &B,
    gpu_matvec: Option<&GpuSchurMatvec>,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
where
    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
{
    let iteration_cap = std::cmp::min(pcg.max_iterations, trust.max_iterations);
    let relative_tol = f64::max(pcg.relative_tolerance, trust.steihaug_relative_tolerance);

    match gpu_matvec {
        Some(shared) => {
            let gpu_mv = Arc::clone(shared);
            steihaug_cg(
                rhs,
                move |p, out| gpu_mv(p, out),
                apply_prec,
                iteration_cap,
                relative_tol,
                trust.radius,
                metric_weights,
            )
        }
        None => steihaug_cg(
            rhs,
            |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend),
            apply_prec,
            iteration_cap,
            relative_tol,
            trust.radius,
            metric_weights,
        ),
    }
}
/// Trivial preconditioner that leaves the residual unchanged.
#[derive(Debug, Clone, Copy)]
struct IdentityPreconditioner;

impl IdentityPreconditioner {
    /// Returns a copy of `r`.
    fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
        r.to_owned()
    }
}
/// Runs Steihaug-CG on an explicitly stored dense matrix.
///
/// Iteration cap and tolerance come from `pcg`; only the radius is read
/// from `trust`.
fn steihaug_dense_system(
    schur: &Array2<f64>,
    rhs: &Array1<f64>,
    preconditioner: &IdentityPreconditioner,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let apply_operator = |p: &Array1<f64>, out: &mut Array1<f64>| dense_matvec(schur, p, out);
    let apply_preconditioner = |r: &Array1<f64>| preconditioner.apply(r);
    steihaug_cg(
        rhs,
        apply_operator,
        apply_preconditioner,
        pcg.max_iterations,
        pcg.relative_tolerance,
        trust.radius,
        metric_weights,
    )
}
/// Preconditioned conjugate gradients with trust-region truncation for
/// `A x = rhs`, starting from `x = 0`.
///
/// * `matvec(p, out)` must write `A p` into `out`.
/// * `apply_preconditioner(r)` returns the preconditioned residual.
/// * `max_iterations` caps the number of operator applications.
/// * `relative_tolerance` is clamped below at zero and scaled by the metric
///   norm of `rhs` to form the absolute residual tolerance.
/// * `trust_radius` that is non-finite or not strictly positive means
///   "unbounded".
/// * `metric_weights`, when present, weight every inner product and norm
///   used below (including the curvature `p·Ap`).
///
/// Returns the step and diagnostics whose `stopping_reason` is `Converged`,
/// `TrustRegion` or `MaxIter`. A zero `rhs` returns a zero step with
/// default diagnostics.
///
/// # Errors
/// `PcgFailed` when a non-positive / non-finite preconditioned inner
/// product or curvature is met and no finite radius is available to
/// truncate against, and whenever the preconditioned inner product turns
/// non-positive or non-finite after a residual update.
///
/// # Panics
/// Panics when `metric_weights` is present with a length different from
/// `rhs.len()`.
fn steihaug_cg<MatVec, ApplyPrec>(
rhs: &Array1<f64>,
mut matvec: MatVec,
mut apply_preconditioner: ApplyPrec,
max_iterations: usize,
relative_tolerance: f64,
trust_radius: f64,
metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
where
MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
{
let n = rhs.len();
if let Some(weights) = metric_weights {
assert_eq!(
weights.len(),
n,
"Steihaug-CG metric weight length must match solve dimension"
);
}
// Normalize the radius: anything that is not a positive finite number
// disables the trust-region checks.
let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
trust_radius
} else {
f64::INFINITY
};
let rhs_norm = metric_norm(rhs.view(), metric_weights);
if rhs_norm == 0.0 {
return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
}
let tol = relative_tolerance.max(0.0) * rhs_norm;
// x = 0, so the initial residual is rhs itself.
let mut x = Array1::<f64>::zeros(n);
let mut r = rhs.clone();
let mut z = apply_preconditioner(&r);
let mut diag = PcgDiagnostics {
precond_apply_calls: 1,
..PcgDiagnostics::default()
};
let mut p = z.clone();
let mut rz = metric_dot(&r, &z, metric_weights);
// A non-positive r·z means the preconditioner is unusable here: with a
// finite radius step along the raw residual to the boundary, otherwise fail.
if rz <= 0.0 || !rz.is_finite() {
if radius.is_finite() {
diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
diag.stopping_reason = PcgStopReason::TrustRegion;
return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
}
return Err(ArrowSchurError::PcgFailed {
reason: "non-positive preconditioned residual in Schur PCG".to_string(),
});
}
// Already within tolerance before any iteration: return the zero step.
// Note the reported relative residual is set to 0.0 on this path.
if metric_norm(r.view(), metric_weights) <= tol {
diag.final_relative_residual = 0.0;
diag.stopping_reason = PcgStopReason::Converged;
return Ok((x, diag));
}
let mut ap = Array1::<f64>::zeros(n);
let mut candidate = Array1::<f64>::zeros(n);
for _ in 0..max_iterations {
// `ap` is reused every iteration, so `matvec` must overwrite it.
matvec(&p, &mut ap);
diag.matvec_calls += 1;
diag.iterations += 1;
let pap = metric_dot(&p, &ap, metric_weights);
// Non-positive curvature: follow `p` to the boundary when bounded.
if pap <= 0.0 || !pap.is_finite() {
if radius.is_finite() {
diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
diag.stopping_reason = PcgStopReason::TrustRegion;
return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
}
return Err(ArrowSchurError::PcgFailed {
reason: "negative curvature in unbounded Schur PCG".to_string(),
});
}
let alpha = rz / pap;
for i in 0..n {
candidate[i] = x[i] + alpha * p[i];
}
// The full CG step would reach or cross the boundary: truncate on it.
// The residual reported here is the one from before this step.
if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
diag.stopping_reason = PcgStopReason::TrustRegion;
return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
}
x.assign(&candidate);
for i in 0..n {
r[i] -= alpha * ap[i];
}
if metric_norm(r.view(), metric_weights) <= tol {
diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
diag.stopping_reason = PcgStopReason::Converged;
return Ok((x, diag));
}
z = apply_preconditioner(&r);
diag.precond_apply_calls += 1;
let rz_next = metric_dot(&r, &z, metric_weights);
// Unlike the checks above, this failure is an error even when a finite
// radius is set.
if rz_next <= 0.0 || !rz_next.is_finite() {
return Err(ArrowSchurError::PcgFailed {
reason: "non-positive or non-finite PCG residual".to_string(),
});
}
// Standard CG direction update.
let beta = rz_next / rz;
for i in 0..n {
p[i] = z[i] + beta * p[i];
}
rz = rz_next;
}
// Iteration cap reached: return the current iterate as-is.
diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
diag.stopping_reason = PcgStopReason::MaxIter;
Ok((x, diag))
}
/// Moves from `x` along `p` to the trust-region boundary.
///
/// Solves `‖x + τ p‖ = radius` (in the weighted metric) for the larger root
/// `τ` and returns `x + τ p`. The discriminant is clamped at zero. When
/// `p` has zero metric length, `x` is returned unchanged.
fn step_to_trust_boundary(
    x: &Array1<f64>,
    p: &Array1<f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> Array1<f64> {
    let pp = metric_dot(p, p, metric_weights);
    if pp == 0.0 {
        return x.clone();
    }
    let xp = metric_dot(x, p, metric_weights);
    let xx = metric_dot(x, x, metric_weights);

    // Larger root of pp·τ² + 2·xp·τ + (xx − radius²) = 0.
    let discriminant = (xp * xp + pp * (radius * radius - xx)).max(0.0);
    let tau = (-xp + discriminant.sqrt()) / pp;

    let mut boundary = x.clone();
    for (i, entry) in boundary.iter_mut().enumerate() {
        *entry += tau * p[i];
    }
    boundary
}
/// Writes the product `a · x` into `out`, overwriting its first
/// `a.nrows()` entries. The row count is used for both loop bounds, i.e.
/// `a` is treated as square.
fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
    let n = a.nrows();
    for row in 0..n {
        out[row] = (0..n).fold(0.0_f64, |acc, col| acc + a[[row, col]] * x[col]);
    }
}
/// Plain Euclidean inner product over the indices of `a`.
fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    (0..a.len()).fold(0.0_f64, |acc, i| acc + a[i] * b[i])
}
/// Inner product `Σ wᵢ aᵢ bᵢ`; without weights this is the plain dot
/// product.
///
/// # Panics
/// Panics when `a` and `b` differ in length, or when weights are supplied
/// with a different length.
fn metric_dot(a: &Array1<f64>, b: &Array1<f64>, metric_weights: Option<&MetricWeights>) -> f64 {
    assert_eq!(a.len(), b.len());
    if let Some(weights) = metric_weights {
        assert_eq!(weights.len(), a.len());
        (0..a.len()).fold(0.0_f64, |acc, i| acc + weights[i] * a[i] * b[i])
    } else {
        dot(a, b)
    }
}
fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
let mut acc = 0.0;
match metric_weights {
Some(weights) => {
assert_eq!(weights.len(), v.len());
for i in 0..v.len() {
acc += weights[i] * v[i] * v[i];
}
}
None => {
for x in v.iter() {
acc += x * x;
}
}
}
acc.sqrt()
}
/// Makes the leading square part of `a` exactly symmetric in place.
///
/// Despite the name, neither triangle is copied over the other: each
/// off-diagonal pair `(i, j)`, `(j, i)` is replaced by the average of the
/// two entries. The diagonal is left untouched.
fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
    let side = std::cmp::min(a.nrows(), a.ncols());
    for row in 1..side {
        for col in 0..row {
            let mean = 0.5 * (a[[row, col]] + a[[col, row]]);
            a[[col, row]] = mean;
            a[[row, col]] = mean;
        }
    }
}
/// Failures reported by the arrow-Schur solver.
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// Factorization of the per-row block `row` failed.
    PerRowFactorFailed { row: usize, reason: String },
    /// The per-row block `row` factored, but its condition estimate is too
    /// large to trust.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// Factorization of the reduced (Schur complement) matrix failed.
    SchurFactorFailed { reason: String },
    /// The iterative (PCG) solve of the reduced system failed.
    PcgFailed { reason: String },
    /// The adaptive proximal correction failed.
    AdaptiveCorrectionFailed { reason: String },
}

/// Human-readable messages, all prefixed with `arrow-Schur:`.
impl std::fmt::Display for ArrowSchurError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PerRowFactorFailed { row, reason } => write!(
                f,
                "arrow-Schur: per-row H_tt^({}) Cholesky failed: {}",
                row, reason
            ),
            Self::PerRowFactorIllConditioned {
                row,
                kappa_estimate,
            } => write!(
                f,
                "arrow-Schur: per-row H_tt^({}) Cholesky succeeded but is \
                 ill-conditioned (kappa_estimate={:e}); Schur \
                 reduction would be numerically contaminated",
                row, kappa_estimate
            ),
            Self::SchurFactorFailed { reason } => write!(
                f,
                "arrow-Schur: Schur complement Cholesky failed: {}",
                reason
            ),
            Self::PcgFailed { reason } => {
                write!(f, "arrow-Schur: Schur PCG failed: {}", reason)
            }
            Self::AdaptiveCorrectionFailed { reason } => write!(
                f,
                "arrow-Schur: adaptive proximal correction failed: {}",
                reason
            ),
        }
    }
}

impl std::error::Error for ArrowSchurError {}
/// Dense Cholesky factorization `a = L Lᵀ`, returning the lower factor `L`.
///
/// Only the lower triangle of `a` is read by the factorization itself; the
/// finiteness scan covers every entry.
///
/// # Errors
/// Returns a message when `a` is not square, contains a non-finite entry,
/// or produces a pivot that is non-finite or not strictly positive.
fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
    }
    if let Some(idx) = a.iter().position(|v| !v.is_finite()) {
        return Err(format!(
            "cholesky_lower: non-finite entry at linear index {idx}"
        ));
    }

    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        // Off-diagonal entries of row i, left to right.
        for j in 0..i {
            let reduced = (0..j).fold(a[[i, j]], |acc, kk| acc - l[[i, kk]] * l[[j, kk]]);
            let pivot_root = l[[j, j]];
            l[[i, j]] = reduced / pivot_root;
        }
        // Diagonal pivot of row i.
        let sum = (0..i).fold(a[[i, i]], |acc, kk| acc - l[[i, kk]] * l[[i, kk]]);
        if !sum.is_finite() || sum <= 0.0 {
            return Err(format!(
                "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
            ));
        }
        l[[i, i]] = sum.sqrt();
    }
    Ok(l)
}
/// Solves `L Lᵀ x = b` for a lower-triangular Cholesky factor `l`.
///
/// A forward substitution with `L` is followed by a back substitution with
/// `Lᵀ`, which is read from `l` by swapping indices.
fn chol_solve_vector(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let n = l.nrows();

    // Forward: L y = b.
    let mut y = Array1::<f64>::zeros(n);
    for i in 0..n {
        let residual = (0..i).fold(b[i], |acc, kk| acc - l[[i, kk]] * y[kk]);
        y[i] = residual / l[[i, i]];
    }

    // Backward: Lᵀ x = y.
    let mut x = Array1::<f64>::zeros(n);
    for i in (0..n).rev() {
        let residual = ((i + 1)..n).fold(y[i], |acc, kk| acc - l[[kk, i]] * x[kk]);
        x[i] = residual / l[[i, i]];
    }
    x
}
/// Solves `L Lᵀ X = B` column by column via `chol_solve_vector`.
///
/// The result has `l.nrows()` rows and `b.ncols()` columns.
fn chol_solve_matrix(l: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = l.nrows();
    let n_rhs = b.ncols();
    let mut solution = Array2::<f64>::zeros((rows, n_rhs));
    let mut rhs_column = Array1::<f64>::zeros(rows);
    for col in 0..n_rhs {
        for (row, slot) in rhs_column.iter_mut().enumerate() {
            *slot = b[[row, col]];
        }
        let solved = chol_solve_vector(l, &rhs_column);
        for row in 0..rows {
            solution[[row, col]] = solved[row];
        }
    }
    solution
}
/// Solves `L X = B` by forward substitution, one column of `B` at a time.
///
/// Only the lower triangle (including the diagonal) of `l` is read.
fn lower_triangular_solve_matrix(l: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = l.nrows();
    let n_rhs = b.ncols();
    let mut solution = Array2::<f64>::zeros((rows, n_rhs));
    for col in 0..n_rhs {
        for i in 0..rows {
            let residual =
                (0..i).fold(b[[i, col]], |acc, kk| acc - l[[i, kk]] * solution[[kk, col]]);
            let pivot = l[[i, i]];
            solution[[i, col]] = residual / pivot;
        }
    }
    solution
}
// Unit tests for the arrow-Schur solver: agreement with a dense reference
// solve, the proximal correction on a scalar problem, and handling of
// ill-conditioned per-row blocks.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Solves a small system (2 rows of dimension 2, 3 shared coefficients)
// through the arrow solver and through the streaming path, then compares
// both against a dense Cholesky solve of the assembled joint system.
#[test]
fn arrow_schur_matches_dense_reference_2x2() {
let n = 2;
let d = 2;
let k = 3;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0, 0.5], [0.2, 1.0, 0.0]];
sys.rows[0].gt = array![0.3_f64, -0.2];
sys.rows[1].htt = array![[1.5_f64, -0.1], [-0.1, 2.0]];
sys.rows[1].htbeta = array![[0.1_f64, 0.5, 0.0], [0.0, 0.3, 1.0]];
sys.rows[1].gt = array![-0.1_f64, 0.4];
sys.hbb = array![[4.0_f64, 0.2, 0.0], [0.2, 5.0, 0.1], [0.0, 0.1, 6.0],];
sys.gb = array![0.5_f64, -0.3, 0.2];
let (delta_t, delta_beta, _diag) = sys.solve(0.0, 0.0).expect("arrow-schur solve");
// The streaming path (chunk size 1) must reproduce the in-memory result
// exactly, not merely to within a tolerance.
let streaming_options = ArrowSolveOptions::direct().with_streaming_chunk_size(Some(1));
let (delta_t_stream, delta_beta_stream, _diag_stream) = sys
.solve_with_options(0.0, 0.0, &streaming_options)
.expect("streaming arrow-schur solve");
assert_eq!(delta_beta, delta_beta_stream);
assert_eq!(delta_t, delta_t_stream);
// Assemble the joint system with the shared (β) block first, followed by
// one d×d block per row and the symmetric coupling blocks.
let total = k + n * d;
let mut hjoint = Array2::<f64>::zeros((total, total));
let mut gjoint = Array1::<f64>::zeros(total);
for a in 0..k {
for b in 0..k {
hjoint[[a, b]] = sys.hbb[[a, b]];
}
gjoint[a] = sys.gb[a];
}
for i in 0..n {
let toff = k + i * d;
for a in 0..d {
for b in 0..d {
hjoint[[toff + a, toff + b]] = sys.rows[i].htt[[a, b]];
}
gjoint[toff + a] = sys.rows[i].gt[a];
for a2 in 0..k {
hjoint[[toff + a, a2]] = sys.rows[i].htbeta[[a, a2]];
hjoint[[a2, toff + a]] = sys.rows[i].htbeta[[a, a2]];
}
}
}
// Reference: solve H x = -g densely.
let lj = cholesky_lower(&hjoint).expect("dense ref PD");
let neg_g = gjoint.mapv(|v| -v);
let xref = chol_solve_vector(&lj, &neg_g);
for a in 0..k {
assert!(
(xref[a] - delta_beta[a]).abs() < 1e-10,
"β[{a}] mismatch: dense {} vs arrow {}",
xref[a],
delta_beta[a]
);
}
for i in 0..n {
for a in 0..d {
let dense = xref[k + i * d + a];
let arrow = delta_t[i * d + a];
assert!(
(dense - arrow).abs() < 1e-10,
"t[{i},{a}] mismatch: dense {dense} vs arrow {arrow}"
);
}
}
}
// Scalar objective f(t) = t^4/4 - t^2 + 2t used by the proximal-correction
// test below.
fn quartic_counterexample_value(t: f64) -> f64 {
0.25 * t.powi(4) - t * t + 2.0 * t
}
// One-row, one-dimensional system with no shared coefficients carrying the
// gradient t^3 - 2t + 2 and Hessian 3t^2 - 2 of the quartic above.
fn quartic_counterexample_system(t: f64) -> ArrowSchurSystem {
let mut sys = ArrowSchurSystem::new(1, 1, 0);
sys.rows[0].gt = array![t.powi(3) - 2.0 * t + 2.0];
sys.rows[0].htt = array![[3.0 * t * t - 2.0]];
sys
}
// Starting at t = 0 a plain Newton iteration on this gradient alternates
// between 0 and 1 (0 -> 1 -> 0 ...). The corrected step must never increase
// the objective and must reach a critical point within 32 iterations.
#[test]
fn proximal_correction_breaks_scalar_newton_cycle() {
let options = ArrowSolveOptions::direct();
let correction = ArrowProximalCorrectionOptions {
initial_ridge: 1e-8,
ridge_growth: 10.0,
max_attempts: 16,
armijo_c1: 1e-4,
gradient_tolerance: 1e-12,
};
let mut t = 0.0_f64;
let mut previous_value = quartic_counterexample_value(t);
for _ in 0..32 {
let sys = quartic_counterexample_system(t);
let accepted = solve_arrow_newton_step_with_proximal_correction(
&sys,
0.0,
0.0,
previous_value,
&options,
&correction,
|delta_t, _delta_beta| quartic_counterexample_value(t + delta_t[0]),
)
.expect("proximal correction should accept a descent step");
assert!(
accepted.trial_objective_value <= previous_value,
"accepted step must not increase the objective"
);
t += accepted.delta_t[0];
previous_value = accepted.trial_objective_value;
}
let final_grad = t.powi(3) - 2.0 * t + 2.0;
assert!(
final_grad.abs() < 1e-7,
"corrected iteration should reach the scalar critical point; t={t}, g={final_grad}"
);
}
// A 2×2 block that is positive definite only by a 1e-14 margin must be
// rejected as ill-conditioned, while a well-conditioned block still factors
// with zero ridge.
#[test]
fn factor_one_row_rejects_barely_pd_block() {
let d = 2;
let k = 2;
let mut row = ArrowRowBlock::new(d, k);
row.htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
row.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row.gt = array![0.0_f64, 0.0];
let err = factor_one_row(&row, 0.0, d, 0)
.expect_err("barely-PD H_tt must be rejected by the condition check");
match err {
ArrowSchurError::PerRowFactorIllConditioned {
row: r,
kappa_estimate,
} => {
assert_eq!(r, 0);
assert!(
kappa_estimate > 1e10,
"kappa estimate should reflect the barely-PD block; got {kappa_estimate:e}"
);
}
other => panic!("expected PerRowFactorIllConditioned, got {other:?}"),
}
let mut row_ok = ArrowRowBlock::new(d, k);
row_ok.htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
row_ok.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row_ok.gt = array![0.0_f64, 0.0];
factor_one_row(&row_ok, 0.0, d, 0)
.expect("well-conditioned block must still factor at ridge_t=0");
}
// The same barely-PD block fails a direct factorization, but the escalating
// solver must still return a finite step for the whole system.
#[test]
fn lm_escalation_recovers_from_ill_conditioned_row() {
let n = 1;
let d = 2;
let k = 2;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
sys.rows[0].gt = array![0.1_f64, -0.2];
sys.hbb = array![[4.0_f64, 0.2], [0.2, 5.0]];
sys.gb = array![0.3_f64, -0.1];
let direct = factor_one_row(&sys.rows[0], 0.0, d, 0);
assert!(matches!(
direct,
Err(ArrowSchurError::PerRowFactorIllConditioned { .. })
));
let options = ArrowSolveOptions::direct();
let (delta_t, delta_beta, _diag) = solve_with_lm_escalation_inner(&sys, 0.0, 0.0, &options)
.expect("LM escalation must recover from PerRowFactorIllConditioned");
for v in delta_t.iter().chain(delta_beta.iter()) {
assert!(v.is_finite(), "recovered step must be finite: {v}");
}
}
}