use ndarray::{Array1, Array2, ArrayView1};
use std::sync::Arc;
use crate::solver::persistent_warm_start::StableHasher;
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, AnalyticPenaltyRegistry, PenaltyTier};
use crate::terms::latent_coord::{LatentCoordValues, LatentManifold};
/// Largest shared-block dimension `k` for which the automatic selectors on
/// `ArrowSolverMode` pick a direct-style mode; above it they pick `InexactPCG`.
const DIRECT_SOLVE_MAX_K: usize = 2_000;
/// Default iteration cap used by `ArrowPcgOptions::default()` and
/// `ArrowTrustRegionOptions::default()`.
const DEFAULT_PCG_MAX_ITERATIONS: usize = 200;
/// Default relative tolerance used by `ArrowPcgOptions::default()` and as the
/// default `steihaug_relative_tolerance` of `ArrowTrustRegionOptions`.
const DEFAULT_PCG_RELATIVE_TOLERANCE: f64 = 1e-4;
/// Default trust-region radius. It is infinite, which presumably leaves the
/// step unconstrained (NOTE(review): confirm against the trust-region solver).
const DEFAULT_TRUST_REGION_RADIUS: f64 = f64::INFINITY;
/// Default `initial_ridge` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_INITIAL_RIDGE: f64 = 1e-8;
/// Default `ridge_growth` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_RIDGE_GROWTH: f64 = 10.0;
/// Default `max_attempts` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_MAX_ATTEMPTS: usize = 16;
/// Default `armijo_c1` of `ArrowProximalCorrectionOptions`.
const DEFAULT_ARMIJO_C1: f64 = 1e-4;
/// Default `gradient_tolerance` of `ArrowProximalCorrectionOptions`.
const DEFAULT_GRADIENT_TOLERANCE: f64 = 1e-10;
/// Manifold-mode fingerprint reported for Euclidean latent coordinates; also
/// the initial value stored in a freshly constructed `ArrowSchurSystem`.
const EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT: u64 = 0;
/// Byte budget (256 MiB) under which the factor cache keeps dense copies of
/// the per-row `htbeta` blocks when no row operator is available.
const ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Shared, thread-safe callable taking an input vector view and writing into
/// a caller-provided output vector. Stored in `ArrowSchurSystem::hbb_matvec`;
/// presumably applies the shared beta-beta Hessian block (TODO confirm).
pub type SharedBetaMatvec =
Arc<dyn for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Shared, thread-safe callable taking a row index, an input vector view and
/// an output vector. `ArrowHtbetaCache::Matvec` invokes it in place of the
/// dense per-row `htbeta` block-times-vector product.
pub type RowHtbetaMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Shared, thread-safe callable that produces the `ArrowRowBlock` for a given
/// row index, or an `ArrowSchurError` when the row cannot be produced.
pub type StreamingArrowRowBuilder =
Arc<dyn Fn(usize) -> Result<ArrowRowBlock, ArrowSchurError> + Send + Sync>;
/// Unsized slice of metric weights (not referenced in the visible part of
/// this file; presumably used by trust-region code further down).
type MetricWeights = [f64];
/// Strategy used to form and solve the reduced (Schur) system.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    /// The streaming accumulator solves each row block against `htbeta` and
    /// subtracts `htbetaᵀ · solved` from the Schur accumulator.
    Direct,
    /// The streaming accumulator whitens `htbeta` with a lower-triangular
    /// solve and subtracts `whitenedᵀ · whitened`.
    SqrtBA,
    /// Iterative mode; rejected by the streaming dense accumulator.
    InexactPCG,
}

impl ArrowSolverMode {
    /// Picks `Direct` while `k` is at most `DIRECT_SOLVE_MAX_K`, otherwise
    /// `InexactPCG`.
    pub const fn automatic(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::Direct
    }

    /// Same threshold as [`Self::automatic`], but the small-`k` choice is
    /// `SqrtBA` instead of `Direct`.
    pub const fn automatic_for_single_precision(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::SqrtBA
    }
}
/// Iteration controls for the PCG solver mode.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
    /// Upper bound on the number of iterations.
    pub max_iterations: usize,
    /// Relative tolerance used as the stopping criterion.
    pub relative_tolerance: f64,
}

impl Default for ArrowPcgOptions {
    /// Uses the module-level PCG defaults for both fields.
    fn default() -> Self {
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        let relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        Self {
            max_iterations,
            relative_tolerance,
        }
    }
}
/// Trust-region controls for the reduced solve.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
    /// Trust-region radius; the default is infinite.
    pub radius: f64,
    /// Relative tolerance for the Steihaug inner solve.
    pub steihaug_relative_tolerance: f64,
    /// Upper bound on the number of inner iterations.
    pub max_iterations: usize,
}

impl Default for ArrowTrustRegionOptions {
    /// Infinite radius, with tolerance and iteration cap shared with the PCG
    /// defaults.
    fn default() -> Self {
        let radius = DEFAULT_TRUST_REGION_RADIUS;
        let steihaug_relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        Self {
            radius,
            steihaug_relative_tolerance,
            max_iterations,
        }
    }
}
/// Complete option set accepted by the Arrow-Schur solve entry points.
#[derive(Debug, Clone)]
pub struct ArrowSolveOptions {
/// Which reduced-system strategy to use.
pub mode: ArrowSolverMode,
/// Iteration controls for the PCG mode.
pub pcg: ArrowPcgOptions,
/// Trust-region controls.
pub trust_region: ArrowTrustRegionOptions,
/// Optional streaming chunk size. `with_streaming_chunk_size` normalises
/// `Some(0)` to `None`; entry points that must build a factor cache reject
/// `Some(_)`.
pub streaming_chunk_size: Option<usize>,
/// Flag for a Riemannian trust region; `false` in every constructor in this
/// file (NOTE(review): its consumer is outside this chunk).
pub riemannian_trust_region: bool,
}
/// Tuning knobs for the proximal correction loop.
///
/// NOTE(review): the loop that consumes these options is outside this chunk;
/// field meanings below follow the field names and default constants only.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
    /// Ridge used on the first attempt.
    pub initial_ridge: f64,
    /// Ridge growth factor (presumably applied between attempts).
    pub ridge_growth: f64,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// Armijo sufficient-decrease constant.
    pub armijo_c1: f64,
    /// Gradient tolerance.
    pub gradient_tolerance: f64,
}

impl Default for ArrowProximalCorrectionOptions {
    /// Populates every field from the matching module-level default.
    fn default() -> Self {
        let (initial_ridge, ridge_growth, max_attempts) = (
            DEFAULT_PROXIMAL_INITIAL_RIDGE,
            DEFAULT_PROXIMAL_RIDGE_GROWTH,
            DEFAULT_PROXIMAL_MAX_ATTEMPTS,
        );
        let (armijo_c1, gradient_tolerance) = (DEFAULT_ARMIJO_C1, DEFAULT_GRADIENT_TOLERANCE);
        Self {
            initial_ridge,
            ridge_growth,
            max_attempts,
            armijo_c1,
            gradient_tolerance,
        }
    }
}
/// Record describing an accepted proximal step.
///
/// NOTE(review): the code that fills this struct is outside this chunk; the
/// field notes below are read off the field names and should be confirmed.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
/// Step in the latent coordinates.
pub delta_t: Array1<f64>,
/// Step in the shared coefficients.
pub delta_beta: Array1<f64>,
/// Ridge applied to the latent blocks.
pub ridge_t: f64,
/// Ridge applied to the shared block.
pub ridge_beta: f64,
/// Proximal ridge at which the step was accepted.
pub proximal_ridge: f64,
/// Objective value before the step (presumably).
pub objective_value: f64,
/// Objective value at the trial point (presumably).
pub trial_objective_value: f64,
/// Inner product of the gradient with the step.
pub gradient_dot_step: f64,
/// Number of attempts taken.
pub attempts: usize,
}
impl ArrowSolveOptions {
    /// Default options with the mode chosen by `ArrowSolverMode::automatic(k)`.
    pub fn automatic(k: usize) -> Self {
        Self {
            mode: ArrowSolverMode::automatic(k),
            ..Self::direct()
        }
    }

    /// Default options with the `Direct` mode: default PCG and trust-region
    /// settings, no streaming, Riemannian trust region off.
    pub fn direct() -> Self {
        let pcg = ArrowPcgOptions::default();
        let trust_region = ArrowTrustRegionOptions::default();
        Self {
            mode: ArrowSolverMode::Direct,
            pcg,
            trust_region,
            streaming_chunk_size: None,
            riemannian_trust_region: false,
        }
    }

    /// Default options with the `SqrtBA` mode.
    pub fn sqrt_ba() -> Self {
        Self {
            mode: ArrowSolverMode::SqrtBA,
            ..Self::direct()
        }
    }

    /// Default options with the `InexactPCG` mode.
    pub fn inexact_pcg() -> Self {
        Self {
            mode: ArrowSolverMode::InexactPCG,
            ..Self::direct()
        }
    }

    /// Builder-style setter for the streaming chunk size; a chunk size of
    /// zero is treated the same as `None`.
    pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
        self.streaming_chunk_size = match chunk_size {
            None | Some(0) => None,
            positive => positive,
        };
        self
    }
}
/// Backend abstraction for the per-row block linear algebra used while
/// forming and back-substituting the Schur complement.
///
/// The behavior notes on each method describe what `CpuBatchedBlockSolver`
/// in this file does; other backends are expected to match it.
pub trait BatchedBlockSolver {
/// Factors every row's `htt` block after adding `ridge_t` to its diagonal.
/// `d` is the expected block dimension; the CPU backend returns one factor
/// per row, in row order, and stops at the first row that fails.
fn factor_blocks(
&self,
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
) -> Result<Vec<Array2<f64>>, ArrowSchurError>;
/// Solves with a block factor against a vector right-hand side
/// (CPU backend: `chol_solve_vector`).
fn solve_block_vector(&self, factor: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64>;
/// Solves with a block factor against a matrix right-hand side
/// (CPU backend: `chol_solve_matrix`).
fn solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64>;
/// Applies only the triangular half of the solve to a matrix right-hand
/// side (CPU backend: `lower_triangular_solve_matrix`).
fn sqrt_solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64>;
/// Updates `schur` in place: `schur[a, b] -= Σ_c left[c, a] * right[c, b]`.
fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
/// CPU implementation of [`BatchedBlockSolver`] that handles one block at a
/// time with the dense helpers defined in this file.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;

impl BatchedBlockSolver for CpuBatchedBlockSolver {
    /// Factors each row in order with `factor_one_row`; the first failing
    /// row aborts the batch and its error is returned.
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
    ) -> Result<Vec<Array2<f64>>, ArrowSchurError> {
        rows.iter()
            .enumerate()
            .map(|(row_idx, row)| factor_one_row(row, ridge_t, d, row_idx))
            .collect()
    }

    /// Delegates to `chol_solve_vector`.
    fn solve_block_vector(&self, factor: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
        chol_solve_vector(factor, rhs)
    }

    /// Delegates to `chol_solve_matrix`.
    fn solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64> {
        chol_solve_matrix(factor, rhs)
    }

    /// Delegates to `lower_triangular_solve_matrix`.
    fn sqrt_solve_block_matrix(&self, factor: &Array2<f64>, rhs: &Array2<f64>) -> Array2<f64> {
        lower_triangular_solve_matrix(factor, rhs)
    }

    /// Computes `schur -= leftᵀ · right` entry by entry.
    ///
    /// # Panics
    /// Panics when `left`, `right` or `schur` do not all have `schur.nrows()`
    /// columns.
    fn block_gemm_subtract(
        &self,
        schur: &mut Array2<f64>,
        left: &Array2<f64>,
        right: &Array2<f64>,
    ) {
        let k = schur.nrows();
        assert_eq!(left.ncols(), k);
        assert_eq!(right.ncols(), k);
        assert_eq!(schur.ncols(), k);
        for c in 0..left.nrows() {
            // Exact zeros in `left` contribute nothing, so skip their rows of work.
            let nonzero_weights = (0..k)
                .map(|a| (a, left[[c, a]]))
                .filter(|&(_, weight)| weight != 0.0);
            for (a, weight) in nonzero_weights {
                for b in 0..k {
                    schur[[a, b]] -= weight * right[[c, b]];
                }
            }
        }
    }
}
/// Cholesky-factors one row's `htt` block after adding `ridge_t` to its
/// diagonal, and rejects factors that look too ill-conditioned.
///
/// # Errors
/// * `PerRowFactorFailed` when `htt` is not `d × d`, when `gt` does not have
///   length `d`, or when the Cholesky factorisation fails.
/// * `PerRowFactorIllConditioned` when the smallest factor diagonal is not
///   positive, the largest is not finite, or the squared max/min diagonal
///   ratio is non-finite or exceeds `1 / (sqrt(f64::EPSILON) * max(d, 1))`.
fn factor_one_row(
    row: &ArrowRowBlock,
    ridge_t: f64,
    d: usize,
    row_idx: usize,
) -> Result<Array2<f64>, ArrowSchurError> {
    let htt_shape = row.htt.dim();
    if htt_shape != (d, d) {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "row {row_idx} H_tt shape {:?} does not match per_point_hessian_block dimension ({d}, {d})",
                htt_shape
            ),
        });
    }
    let gt_len = row.gt.len();
    if gt_len != d {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "row {row_idx} g_t length {} does not match latent dimension {d}",
                gt_len
            ),
        });
    }

    // Damped copy of the block: H_tt + ridge_t * I.
    let mut damped = row.htt.clone();
    for diag_idx in 0..d {
        damped[[diag_idx, diag_idx]] += ridge_t;
    }

    let factor = match cholesky_lower(&damped) {
        Ok(lower) => lower,
        Err(e) => {
            return Err(ArrowSchurError::PerRowFactorFailed {
                row: row_idx,
                reason: format!(
                    "row {row_idx} H_tt was non-PD at ridge_t={ridge_t}; cholesky error: {e}"
                ),
            });
        }
    };

    // Extremes of the factor diagonal. Strict comparisons mean NaN entries
    // never replace the running extremes.
    let (min_diag, max_diag) = (0..d).map(|a| factor[[a, a]]).fold(
        (f64::INFINITY, 0.0_f64),
        |(lowest, highest), value| {
            (
                if value < lowest { value } else { lowest },
                if value > highest { value } else { highest },
            )
        },
    );

    let diagonal_usable = min_diag > 0.0 && max_diag.is_finite();
    if !diagonal_usable {
        return Err(ArrowSchurError::PerRowFactorIllConditioned {
            row: row_idx,
            kappa_estimate: f64::INFINITY,
        });
    }

    // The factor diagonal is compared squared because the factor is a
    // square root of the block.
    let ratio = max_diag / min_diag;
    let kappa_est = ratio * ratio;
    let d_scale = (d as f64).max(1.0);
    let kappa_max = 1.0 / (f64::EPSILON.sqrt() * d_scale);
    if !kappa_est.is_finite() || kappa_est > kappa_max {
        return Err(ArrowSchurError::PerRowFactorIllConditioned {
            row: row_idx,
            kappa_estimate: kappa_est,
        });
    }
    Ok(factor)
}
/// Fingerprints the latent manifold configuration.
///
/// Euclidean manifolds map to `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT`. Anything
/// else hashes a version tag, the observation count, the latent dimension,
/// the manifold structure and the flattened metric weights.
fn manifold_mode_fingerprint(latent: &LatentCoordValues) -> u64 {
    let manifold = latent.manifold();
    if !manifold.is_euclidean() {
        let mut weights = Vec::new();
        append_latent_metric_weights(&mut weights, manifold);

        let mut hasher = StableHasher::new();
        hasher.write_str("arrow-schur-manifold-mode-v1");
        hasher.write_usize(latent.n_obs());
        hasher.write_usize(latent.latent_dim());
        write_latent_manifold(&mut hasher, manifold);
        hasher.write_usize(weights.len());
        for w in weights.iter().copied() {
            hasher.write_f64(w);
        }
        return hasher.finish_u64();
    }
    EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT
}
fn row_hessian_fingerprint_for_system(sys: &ArrowSchurSystem) -> u64 {
let mut hasher = StableHasher::new();
hasher.write_str("arrow-schur-row-hessian-v2");
hasher.write_usize(sys.rows.len());
hasher.write_usize(sys.d);
hasher.write_usize(sys.k);
for row in sys.rows.iter() {
write_array2_fingerprint(&mut hasher, &row.htt);
write_array2_fingerprint(&mut hasher, &row.htbeta);
}
write_array2_fingerprint(&mut hasher, &sys.hbb);
match sys.hbb_diag.as_ref() {
Some(diag) => {
hasher.write_bool(true);
hasher.write_usize(diag.len());
for &value in diag.iter() {
hasher.write_f64(value);
}
}
None => hasher.write_bool(false),
}
hasher.finish_u64()
}
/// Merges the row-Hessian fingerprint with the analytic-penalty registry
/// fingerprint. A registry fingerprint of `0` means "no penalties" and
/// leaves the row fingerprint unchanged.
fn combine_row_and_registry_fingerprints(row: u64, registry: u64) -> u64 {
    match registry {
        0 => row,
        _ => {
            let mut hasher = StableHasher::new();
            hasher.write_str("arrow-schur-row-hessian-with-penalties-v1");
            hasher.write_u64(row);
            hasher.write_u64(registry);
            hasher.finish_u64()
        }
    }
}
/// Softplus `ln(1 + e^x)` with cut-offs at ±30: above 30 it returns `x`,
/// below −30 it returns `e^x`, otherwise the direct formula.
///
/// The result feeds a hash, so the exact floating-point expressions must not
/// be replaced by "more accurate" variants such as `ln_1p`.
fn stable_softplus_for_fingerprint(x: f64) -> f64 {
    match x {
        large if large > 30.0 => large,
        small if small < -30.0 => small.exp(),
        // NaN fails both guards and propagates through the direct formula.
        moderate => (1.0 + moderate.exp()).ln(),
    }
}
/// Feeds a matrix into `hasher`: row count, column count, then every element
/// in the array's iteration order.
fn write_array2_fingerprint(hasher: &mut StableHasher, values: &Array2<f64>) {
    let (n_rows, n_cols) = values.dim();
    hasher.write_usize(n_rows);
    hasher.write_usize(n_cols);
    for element in values.iter().copied() {
        hasher.write_f64(element);
    }
}
/// Fingerprints the row-Hessian contribution of one analytic penalty.
///
/// Returns `None` for penalties that are not in the `Psi` tier or that are
/// not row-block-diagonal. Otherwise the hash covers the penalty name, the
/// target and `rho_local` lengths, every `rho_local` value, and
/// variant-specific state:
///
/// * `RowPrecisionPrior`: shape and contents of `lambda_per_row`, the weight,
///   and (when the weight is learnable) its rho index and effective value.
/// * `ParametricRowPrecisionPrior`: shapes, weight, `aux`, and for each of
///   `log_alpha`, `raw_beta` and `mu` both the stored value and the value
///   shifted by the matching `rho_local` entry.
/// * any other variant: the Hessian diagonal, or a zero length marker when
///   the penalty does not provide one.
///
/// The `rho_local` layout assumed for the parametric prior is
/// `[log_alpha | raw_beta | mu (row-major) | weight]`, consistent with the
/// offsets computed below.
///
/// # Panics
/// Panics if `rho_local` is shorter than the indices the variant requires.
fn analytic_penalty_row_hessian_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> Option<u64> {
    if penalty.tier() != PenaltyTier::Psi || !analytic_penalty_is_row_block_diagonal(penalty) {
        return None;
    }
    let mut hasher = StableHasher::new();
    hasher.write_str("arrow-schur-analytic-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for &rho in rho_local.iter() {
        hasher.write_f64(rho);
    }
    match penalty {
        AnalyticPenaltyKind::RowPrecisionPrior(p) => {
            let (n, rows, cols) = p.lambda_per_row.dim();
            hasher.write_str("row-precision-fixed");
            hasher.write_usize(n);
            hasher.write_usize(rows);
            hasher.write_usize(cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            if p.learnable_weight {
                hasher.write_usize(p.rho_index);
                hasher.write_f64(p.weight * rho_local[p.rho_index].exp());
            }
            for &value in p.lambda_per_row.iter() {
                hasher.write_f64(value);
            }
        }
        AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
            let (aux_n, aux_dim) = p.aux.dim();
            let (mu_rows, mu_cols) = p.mu.dim();
            let raw_beta_offset = p.log_alpha.len();
            let mu_offset = raw_beta_offset + p.raw_beta.len();
            let weight_offset = mu_offset + p.mu.len();
            hasher.write_str("row-precision-parametric");
            hasher.write_usize(aux_n);
            hasher.write_usize(aux_dim);
            hasher.write_usize(mu_rows);
            hasher.write_usize(mu_cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            for &value in p.aux.iter() {
                hasher.write_f64(value);
            }
            for k in 0..p.log_alpha.len() {
                let active_log_alpha = p.log_alpha[k] + rho_local[k];
                hasher.write_f64(p.log_alpha[k]);
                hasher.write_f64(active_log_alpha);
                hasher.write_f64(active_log_alpha.exp());
            }
            for k in 0..p.raw_beta.len() {
                let active_raw_beta = p.raw_beta[k] + rho_local[raw_beta_offset + k];
                hasher.write_f64(p.raw_beta[k]);
                hasher.write_f64(active_raw_beta);
                hasher.write_f64(stable_softplus_for_fingerprint(active_raw_beta));
            }
            for k in 0..mu_rows {
                for a in 0..mu_cols {
                    // Row-major flat index into the `mu` segment. The stride
                    // must be `mu`'s own column count so the segment spans
                    // exactly `mu.len()` entries and ends at `weight_offset`.
                    let idx = mu_offset + k * mu_cols + a;
                    hasher.write_f64(p.mu[[k, a]]);
                    hasher.write_f64(p.mu[[k, a]] + rho_local[idx]);
                }
            }
            if p.learnable_weight {
                hasher.write_usize(weight_offset);
                hasher.write_f64(p.weight * rho_local[weight_offset].exp());
            }
        }
        _ => {
            hasher.write_str("row-block-diagonal");
            if let Some(diag) = penalty.hessian_diag(target_t, rho_local) {
                hasher.write_usize(diag.len());
                for &value in diag.iter() {
                    hasher.write_f64(value);
                }
            } else {
                hasher.write_usize(0);
            }
        }
    }
    Some(hasher.finish_u64())
}
/// Recursively feeds the structure of `manifold` into `hasher`: a tag string
/// per variant followed by that variant's parameters. Product variants write
/// their part count and then each part in order.
fn write_latent_manifold(hasher: &mut StableHasher, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::ProductWithMetric { manifolds, weights } => {
            hasher.write_str("product-with-metric");
            hasher.write_usize(manifolds.len());
            for component in manifolds {
                write_latent_manifold(hasher, component);
            }
            hasher.write_usize(weights.len());
            for w in weights {
                hasher.write_f64(*w);
            }
        }
        LatentManifold::Product(components) => {
            hasher.write_str("product");
            hasher.write_usize(components.len());
            for component in components {
                write_latent_manifold(hasher, component);
            }
        }
        LatentManifold::Interval { lo, hi } => {
            hasher.write_str("interval");
            for bound in [*lo, *hi] {
                hasher.write_f64(bound);
            }
        }
        LatentManifold::Sphere { dim } => {
            hasher.write_str("sphere");
            hasher.write_usize(*dim);
        }
        LatentManifold::Circle { period } => {
            hasher.write_str("circle");
            hasher.write_f64(*period);
        }
        LatentManifold::Euclidean => {
            hasher.write_str("euclidean");
        }
    }
}
/// Appends the metric weights implied by `manifold` to `out`.
///
/// * Euclidean: a single `1.0`.
/// * Circle: `1 / period²`.
/// * Sphere: `dim` copies of `1 / π²`.
/// * Interval: `1 / (hi − lo)²`.
/// * Product: the weights of each part, in order.
/// * ProductWithMetric: the explicit weights, ignoring the parts.
fn append_latent_metric_weights(out: &mut Vec<f64>, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::ProductWithMetric { weights, .. } => {
            out.extend(weights.iter().copied());
        }
        LatentManifold::Product(components) => {
            for component in components {
                append_latent_metric_weights(out, component);
            }
        }
        LatentManifold::Interval { lo, hi } => {
            let width = hi - lo;
            out.push(1.0 / (width * width));
        }
        LatentManifold::Sphere { dim } => {
            let scale = std::f64::consts::PI;
            let weight = 1.0 / (scale * scale);
            out.extend(std::iter::repeat(weight).take(*dim));
        }
        LatentManifold::Circle { period } => {
            out.push(1.0 / (period * period));
        }
        LatentManifold::Euclidean => out.push(1.0),
    }
}
/// Per-row pieces of the arrow-structured system.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    /// `d × d` latent-latent Hessian block of this row.
    pub htt: Array2<f64>,
    /// `d × k` coupling block between this row's latent coordinates and the
    /// shared coefficients.
    pub htbeta: Array2<f64>,
    /// Length-`d` gradient with respect to this row's latent coordinates.
    pub gt: Array1<f64>,
}

impl ArrowRowBlock {
    /// Creates an all-zero block for latent dimension `d` and shared
    /// dimension `k`.
    pub fn new(d: usize, k: usize) -> Self {
        let htt = Array2::zeros((d, d));
        let htbeta = Array2::zeros((d, k));
        let gt = Array1::zeros(d);
        Self { htt, htbeta, gt }
    }
}
/// Arrow-structured Newton system: one small block per row plus a shared
/// block that couples to every row.
pub struct ArrowSchurSystem {
    /// Per-row blocks (`htt`, `htbeta`, `gt`).
    pub rows: Vec<ArrowRowBlock>,
    /// Dense shared block, `k × k`; `0 × 0` when the system was built
    /// matrix-free via [`Self::new_matrix_free_shared`].
    pub hbb: Array2<f64>,
    /// Optional operator form of the shared block.
    pub hbb_matvec: Option<SharedBetaMatvec>,
    /// Optional operator form of the per-row coupling blocks.
    pub htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Optional diagonal of the shared block (supplied together with
    /// `hbb_matvec`).
    pub hbb_diag: Option<Array1<f64>>,
    /// Gradient with respect to the shared coefficients, length `k`.
    pub gb: Array1<f64>,
    /// Latent dimension per row.
    pub d: usize,
    /// Shared coefficient dimension.
    pub k: usize,
    /// Fingerprint of the latent manifold configuration.
    pub manifold_mode_fingerprint: u64,
    /// Cached combined fingerprint of the Hessian content and the analytic
    /// penalties; see [`Self::refresh_row_hessian_fingerprint`].
    pub row_hessian_fingerprint: u64,
    /// Fingerprint of the analytic penalties folded into the rows
    /// (`0` when none).
    pub analytic_row_hessian_fingerprint: u64,
}

impl ArrowSchurSystem {
    /// Creates an all-zero system with `n` rows, a dense `k × k` shared
    /// block and no operators.
    pub fn new(n: usize, d: usize, k: usize) -> Self {
        let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
        let mut sys = Self {
            rows,
            hbb: Array2::<f64>::zeros((k, k)),
            hbb_matvec: None,
            htbeta_matvec: None,
            hbb_diag: None,
            gb: Array1::<f64>::zeros(k),
            d,
            k,
            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
            row_hessian_fingerprint: 0,
            analytic_row_hessian_fingerprint: 0,
        };
        sys.refresh_row_hessian_fingerprint();
        sys
    }

    /// Creates an all-zero system whose shared block is available only as an
    /// operator (`matvec`) plus its diagonal (`diag`); the dense `hbb` is
    /// left `0 × 0`.
    ///
    /// # Panics
    /// Panics when `diag.len() != k`.
    pub fn new_matrix_free_shared<F>(
        n: usize,
        d: usize,
        k: usize,
        matvec: F,
        diag: Array1<f64>,
    ) -> Self
    where
        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
    {
        assert_eq!(diag.len(), k);
        let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
        let mut sys = Self {
            rows,
            hbb: Array2::<f64>::zeros((0, 0)),
            hbb_matvec: Some(Arc::new(matvec)),
            htbeta_matvec: None,
            hbb_diag: Some(diag),
            gb: Array1::<f64>::zeros(k),
            d,
            k,
            manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
            row_hessian_fingerprint: 0,
            analytic_row_hessian_fingerprint: 0,
        };
        sys.refresh_row_hessian_fingerprint();
        sys
    }

    /// Number of rows.
    pub fn n(&self) -> usize {
        self.rows.len()
    }

    /// Hashes the current Hessian content (rows, `hbb`, `hbb_diag`) without
    /// the analytic-penalty fingerprint.
    pub fn compute_row_hessian_fingerprint(&self) -> u64 {
        row_hessian_fingerprint_for_system(self)
    }

    /// Recomputes the combined fingerprint from the current content; does
    /// not read or update the cached `row_hessian_fingerprint` field.
    pub fn current_row_hessian_fingerprint(&self) -> u64 {
        combine_row_and_registry_fingerprints(
            self.compute_row_hessian_fingerprint(),
            self.analytic_row_hessian_fingerprint,
        )
    }

    /// Stores the freshly computed combined fingerprint in
    /// `row_hessian_fingerprint`.
    pub fn refresh_row_hessian_fingerprint(&mut self) {
        self.row_hessian_fingerprint = self.current_row_hessian_fingerprint();
    }

    /// Installs the shared-block operator and its diagonal, then refreshes
    /// the cached fingerprint.
    ///
    /// # Panics
    /// Panics when `diag.len() != self.k`.
    pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
    where
        F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
    {
        assert_eq!(diag.len(), self.k);
        self.hbb_matvec = Some(Arc::new(matvec));
        self.hbb_diag = Some(diag);
        self.refresh_row_hessian_fingerprint();
    }

    /// Installs the per-row coupling operator, then refreshes the cached
    /// fingerprint.
    pub fn set_row_htbeta_operator<F>(&mut self, matvec: F)
    where
        F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
    {
        self.htbeta_matvec = Some(Arc::new(matvec));
        self.refresh_row_hessian_fingerprint();
    }

    /// Adds the gradient and Hessian contributions of every penalty in
    /// `registry` to this system.
    ///
    /// `Psi`-tier penalties are scattered into the row blocks, `Beta`-tier
    /// penalties into the shared block, and `Rho`-tier penalties are ignored.
    /// Each penalty reads its own slice of `rho_global` as given by the
    /// registry's rho layout. Afterwards the analytic-penalty fingerprint and
    /// the cached combined fingerprint are updated.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` when any `Psi`-tier penalty is not
    /// row-block-diagonal. This is checked for all penalties before anything
    /// is modified, so on error the system is left untouched.
    pub fn add_analytic_penalty_contributions(
        &mut self,
        registry: &AnalyticPenaltyRegistry,
        target_t: ArrayView1<'_, f64>,
        target_beta: ArrayView1<'_, f64>,
        rho_global: ArrayView1<'_, f64>,
    ) -> Result<(), ArrowSchurError> {
        let layout = registry.rho_layout();

        // Validate first: failing midway would leave earlier penalties
        // already applied while the cached fingerprint still describes the
        // pre-call state.
        for (penalty, (_, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
            if matches!(tier, PenaltyTier::Psi) && !analytic_penalty_is_row_block_diagonal(penalty)
            {
                return Err(ArrowSchurError::SchurFactorFailed {
                    reason: format!(
                        "analytic penalty {name:?} couples latent rows; cross-row Hessian contributions are not yet supported on any production solver path. Consider using a row-block-only penalty (ARDPenalty, SparsityPenalty, SoftmaxAssignmentSparsity, IBPAssignment) or filing an issue requesting cross-row Hessian support."
                    ),
                });
            }
        }

        let mut penalty_fingerprints = Vec::new();
        for (penalty, (rho_slice, tier, _)) in registry.penalties.iter().zip(layout.iter()) {
            let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
            match tier {
                PenaltyTier::Psi => {
                    self.add_ext_coord_penalty(penalty, target_t, rho_local);
                    if let Some(fingerprint) =
                        analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
                    {
                        penalty_fingerprints.push(fingerprint);
                    }
                }
                PenaltyTier::Beta => {
                    self.add_beta_penalty(penalty, target_beta, rho_local);
                }
                PenaltyTier::Rho => {}
            }
        }

        self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
            0
        } else {
            let mut hasher = StableHasher::new();
            hasher.write_str("arrow-schur-row-hessian-registry-v1");
            hasher.write_usize(penalty_fingerprints.len());
            for fingerprint in penalty_fingerprints {
                hasher.write_u64(fingerprint);
            }
            hasher.finish_u64()
        };
        self.refresh_row_hessian_fingerprint();
        Ok(())
    }

    /// Converts every row from Euclidean to Riemannian quantities for the
    /// manifold carried by `latent`, and records the manifold fingerprint.
    ///
    /// For a Euclidean manifold only the fingerprints are refreshed. Otherwise
    /// each row's gradient and coupling block are projected to the tangent
    /// space at that row's latent point, and `htt` is replaced by the
    /// manifold's Riemannian Hessian built from the original gradient and
    /// Hessian.
    ///
    /// # Panics
    /// For non-Euclidean manifolds, panics when `latent` disagrees with the
    /// system on the number of rows or the latent dimension.
    pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
        let manifold = latent.manifold();
        self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
        if manifold.is_euclidean() {
            self.refresh_row_hessian_fingerprint();
            return;
        }
        assert_eq!(latent.n_obs(), self.rows.len());
        assert_eq!(latent.latent_dim(), self.d);
        for (i, row) in self.rows.iter_mut().enumerate() {
            let t_i = ArrayView1::from(latent.row(i));
            // Keep the Euclidean quantities: the Riemannian Hessian needs the
            // unprojected gradient.
            let gt_e = row.gt.clone();
            let htt_e = row.htt.clone();
            let htbeta_e = row.htbeta.clone();
            row.gt = manifold.project_to_tangent(t_i, gt_e.view());
            row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
            row.htbeta = manifold.project_matrix_columns_to_tangent(t_i, htbeta_e.view());
        }
        self.refresh_row_hessian_fingerprint();
    }

    /// Scatters a latent-coordinate penalty into the row blocks.
    ///
    /// `target_t` is the flattened latent vector of length `n * d`; flat
    /// index `f` maps to row `f / d`, coordinate `f % d`. When the penalty
    /// has no Hessian diagonal, column `a` of every row block is recovered
    /// from one Hessian-vector product with a probe that is 1 at coordinate
    /// `a` of every row (valid because the penalty is row-block-diagonal).
    fn add_ext_coord_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_t: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let d = self.d;
        let n = self.rows.len();
        apply_analytic_penalty(
            penalty,
            target_t,
            rho_local,
            n * d,
            d,
            self,
            |sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
            |sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
            |a, probe| {
                for i in 0..n {
                    probe[i * d + a] = 1.0;
                }
            },
            |sys, a, hv| {
                for i in 0..n {
                    for b in 0..d {
                        sys.rows[i].htt[[b, a]] += hv[i * d + b];
                    }
                }
            },
        );
    }

    /// Scatters a shared-coefficient penalty into `gb`, `hbb` and `hbb_diag`.
    ///
    /// Diagonal Hessians update the dense block (when it is `k × k`) and the
    /// stored diagonal (when present). Non-diagonal Hessians are probed
    /// column by column only when the dense block exists; in the matrix-free
    /// configuration no Hessian-vector probes are run.
    fn add_beta_penalty(
        &mut self,
        penalty: &AnalyticPenaltyKind,
        target_beta: ArrayView1<'_, f64>,
        rho_local: ArrayView1<'_, f64>,
    ) {
        let k = self.k;
        let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
        apply_analytic_penalty(
            penalty,
            target_beta,
            rho_local,
            k,
            hvp_columns,
            self,
            |sys, j, value| sys.gb[j] += value,
            |sys, j, value| {
                if sys.hbb.dim() == (k, k) {
                    sys.hbb[[j, j]] += value;
                }
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += value;
                }
            },
            |j, probe| probe[j] = 1.0,
            |sys, j, hv| {
                for i in 0..k {
                    sys.hbb[[i, j]] += hv[i];
                }
                if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
                    hbb_diag[j] += hv[j];
                }
            },
        );
    }

    /// Solves for the Newton step `(delta_t, delta_beta)` with automatically
    /// chosen options; delegates to `solve_arrow_newton_step_core`.
    pub fn solve(
        &self,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        let options = ArrowSolveOptions::automatic(self.k);
        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &options)
    }

    /// Like [`Self::solve`], but delegates to
    /// `solve_with_lm_escalation_inner`.
    pub fn solve_with_lm_escalation(
        &self,
        ridge_t: f64,
        ridge_beta: f64,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        let options = ArrowSolveOptions::automatic(self.k);
        solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &options)
    }

    /// Solves for the Newton step with caller-supplied options; delegates to
    /// `solve_arrow_newton_step_core`.
    pub fn solve_with_options(
        &self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
    }
}
pub struct StreamingArrowSchur {
pub n_rows: usize,
pub d: usize,
pub k: usize,
pub chunk_size: usize,
pub s_acc: Array2<f64>,
rhs_acc: Array1<f64>,
hbb: Array2<f64>,
gb: Array1<f64>,
row_builder: StreamingArrowRowBuilder,
}
impl std::fmt::Debug for StreamingArrowSchur {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StreamingArrowSchur")
.field("n_rows", &self.n_rows)
.field("d", &self.d)
.field("k", &self.k)
.field("chunk_size", &self.chunk_size)
.finish_non_exhaustive()
}
}
impl StreamingArrowSchur {
#[must_use]
pub fn new(
n_rows: usize,
d: usize,
k: usize,
hbb: Array2<f64>,
gb: Array1<f64>,
row_builder: StreamingArrowRowBuilder,
chunk_size: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
assert_eq!(gb.len(), k);
Self {
n_rows,
d,
k,
chunk_size: chunk_size.max(1),
s_acc: Array2::<f64>::zeros((k, k)),
rhs_acc: Array1::<f64>::zeros(k),
hbb,
gb,
row_builder,
}
}
#[must_use]
pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
let rows = Arc::new(sys.rows.clone());
let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
rows.get(row)
.cloned()
.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
reason: format!("streaming row {row} out of bounds"),
})
});
Self::new(
sys.rows.len(),
sys.d,
sys.k,
sys.hbb.clone(),
sys.gb.clone(),
row_builder,
chunk_size,
)
}
pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
if self.hbb.dim() != (self.k, self.k) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
});
}
self.s_acc.assign(&self.hbb);
for j in 0..self.k {
self.s_acc[[j, j]] += ridge_beta;
self.rhs_acc[j] = 0.0;
}
Ok(())
}
pub fn accumulate_chunk(
&mut self,
start: usize,
end: usize,
ridge_t: f64,
mode: ArrowSolverMode,
) -> Result<(), ArrowSchurError> {
if start > end || end > self.n_rows {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
self.n_rows
),
});
}
let backend = CpuBatchedBlockSolver;
for row_idx in start..end {
let row = (self.row_builder)(row_idx)?;
self.validate_row(row_idx, &row)?;
let factor = factor_one_row(&row, ridge_t, self.d, row_idx)?;
let v = backend.solve_block_vector(&factor, &row.gt);
for c in 0..self.d {
let vc = v[c];
if vc == 0.0 {
continue;
}
for a in 0..self.k {
self.rhs_acc[a] += row.htbeta[[c, a]] * vc;
}
}
match mode {
ArrowSolverMode::Direct => {
let solved = backend.solve_block_matrix(&factor, &row.htbeta);
backend.block_gemm_subtract(&mut self.s_acc, &row.htbeta, &solved);
}
ArrowSolverMode::SqrtBA => {
let whitened = backend.sqrt_solve_block_matrix(&factor, &row.htbeta);
backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
}
ArrowSolverMode::InexactPCG => {
return Err(ArrowSchurError::PcgFailed {
reason: "streaming Arrow-Schur accumulator is for dense direct modes; use matrix-free PCG without streaming_chunk_size".to_string(),
});
}
}
}
Ok(())
}
pub fn solve(
&mut self,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
self.reset_accumulator(ridge_beta)?;
for start in (0..self.n_rows).step_by(self.chunk_size) {
let end = (start + self.chunk_size).min(self.n_rows);
self.accumulate_chunk(start, end, ridge_t, options.mode)?;
}
for j in 0..self.k {
self.rhs_acc[j] -= self.gb[j];
}
symmetrize_upper_from_lower(&mut self.s_acc);
let trust_metric_weights = None;
let (delta_beta, schur_factor) =
solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
Ok((delta_t, delta_beta, schur_factor))
}
fn back_substitute(
&self,
ridge_t: f64,
delta_beta: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, ArrowSchurError> {
let backend = CpuBatchedBlockSolver;
let mut delta_t = Array1::<f64>::zeros(self.n_rows * self.d);
let mut rhs = Array1::<f64>::zeros(self.d);
for start in (0..self.n_rows).step_by(self.chunk_size) {
let end = (start + self.chunk_size).min(self.n_rows);
for row_idx in start..end {
let row = (self.row_builder)(row_idx)?;
self.validate_row(row_idx, &row)?;
let factor = factor_one_row(&row, ridge_t, self.d, row_idx)?;
for c in 0..self.d {
let mut acc = row.gt[c];
for a in 0..self.k {
acc += row.htbeta[[c, a]] * delta_beta[a];
}
rhs[c] = acc;
}
let dt_i = backend.solve_block_vector(&factor, &rhs);
let row_base = row_idx * self.d;
for c in 0..self.d {
delta_t[row_base + c] = -dt_i[c];
}
}
}
Ok(delta_t)
}
fn validate_row(&self, row_idx: usize, row: &ArrowRowBlock) -> Result<(), ArrowSchurError> {
if row.htt.dim() != (self.d, self.d) {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"streaming row H_tt shape {:?} != ({}, {})",
row.htt.dim(),
self.d,
self.d
),
});
}
if row.htbeta.dim() != (self.d, self.k) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"streaming row H_tβ shape {:?} != ({}, {})",
row.htbeta.dim(),
self.d,
self.k
),
});
}
if row.gt.len() != self.d {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!("streaming row g_t length {} != {}", row.gt.len(), self.d),
});
}
Ok::<(), _>(())
}
}
/// Applies one analytic penalty to an arbitrary scatter target.
///
/// The gradient is always scattered entry by entry through `grad_scatter`.
/// For the Hessian, a diagonal is scattered through `diag_scatter` when the
/// penalty provides one. Otherwise `hvp_columns` Hessian-vector products are
/// taken, each with a zeroed probe seeded by `seed_hvp_probe`, and handed to
/// `hvp_column_scatter`.
///
/// # Panics
/// Panics when `target.len() != expected_target_len`, or when a provided
/// Hessian diagonal has a different length.
fn apply_analytic_penalty<S, G, D, P, H>(
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    expected_target_len: usize,
    hvp_columns: usize,
    scatter_target: &mut S,
    mut grad_scatter: G,
    mut diag_scatter: D,
    seed_hvp_probe: P,
    mut hvp_column_scatter: H,
) where
    G: FnMut(&mut S, usize, f64),
    D: FnMut(&mut S, usize, f64),
    P: Fn(usize, &mut Array1<f64>),
    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
{
    assert_eq!(target.len(), expected_target_len);

    let gradient = penalty.grad_target(target, rho_local);
    for slot in 0..expected_target_len {
        grad_scatter(scatter_target, slot, gradient[slot]);
    }

    match penalty.hessian_diag(target, rho_local) {
        Some(diagonal) => {
            assert_eq!(diagonal.len(), expected_target_len);
            for slot in 0..expected_target_len {
                diag_scatter(scatter_target, slot, diagonal[slot]);
            }
        }
        None => {
            // No closed-form diagonal: recover the requested columns from
            // Hessian-vector products.
            let mut probe = Array1::<f64>::zeros(expected_target_len);
            for column in 0..hvp_columns {
                probe.fill(0.0);
                seed_hvp_probe(column, &mut probe);
                let product = penalty.hvp(target, rho_local, probe.view());
                hvp_column_scatter(scatter_target, column, product.view());
            }
        }
    }
}
/// Returns whether `penalty` reports itself as row-block-diagonal; thin
/// wrapper over `AnalyticPenaltyKind::is_row_block_diagonal`.
fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
penalty.is_row_block_diagonal()
}
/// Storage for the per-row factors computed without the latent ridge.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
    /// The damped factors double as the undamped ones.
    SameAsDamped,
    /// Separately computed undamped factors, one per row.
    Owned(Arc<[Array2<f64>]>),
}

impl std::fmt::Debug for ArrowUndampedFactors {
    /// Prints the variant name; for `Owned`, the number of factors instead
    /// of their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Self::Owned(factors) = self {
            let count = factors.len();
            return f.debug_tuple("Owned").field(&count).finish();
        }
        f.write_str("SameAsDamped")
    }
}
/// How the factor cache can apply a row's `htbeta` block to a vector.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
    /// Dense copies of every row's block.
    Dense {
        blocks: Arc<[Array2<f64>]>,
        estimated_bytes: usize,
    },
    /// A shared per-row operator.
    Matvec {
        op: RowHtbetaMatvec,
        estimated_bytes: usize,
    },
    /// Nothing cached; products are unavailable.
    Disabled {
        estimated_bytes: usize,
    },
}

impl std::fmt::Debug for ArrowHtbetaCache {
    /// Prints the variant with its byte estimate; `Dense` additionally shows
    /// the number of blocks rather than their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, block_count, estimated_bytes) = match self {
            Self::Dense {
                blocks,
                estimated_bytes,
            } => ("Dense", Some(blocks.len()), estimated_bytes),
            Self::Matvec {
                estimated_bytes, ..
            } => ("Matvec", None, estimated_bytes),
            Self::Disabled { estimated_bytes } => ("Disabled", None, estimated_bytes),
        };
        let mut printer = f.debug_struct(label);
        if let Some(count) = &block_count {
            printer.field("blocks", count);
        }
        printer.field("estimated_bytes", estimated_bytes);
        printer.finish()
    }
}

impl ArrowHtbetaCache {
    /// `true` for every variant except `Disabled`.
    fn is_available(&self) -> bool {
        match self {
            Self::Disabled { .. } => false,
            Self::Dense { .. } | Self::Matvec { .. } => true,
        }
    }

    /// Writes `htbeta[row] · delta_beta` into `out`.
    ///
    /// Returns `false` without touching `out` when the cache is disabled,
    /// when a dense cache has no block for `row`, or when the dense block's
    /// shape does not match `delta_beta` / `out`. The operator variant always
    /// reports success.
    fn apply_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        match self {
            Self::Disabled { .. } => false,
            Self::Matvec { op, .. } => {
                op(row, delta_beta, out);
                true
            }
            Self::Dense { blocks, .. } => match blocks.get(row) {
                Some(block)
                    if block.ncols() == delta_beta.len() && block.nrows() == out.len() =>
                {
                    // Plain left-to-right accumulation per output entry.
                    for (c, slot) in out.iter_mut().enumerate() {
                        *slot = block
                            .row(c)
                            .iter()
                            .zip(delta_beta.iter())
                            .fold(0.0_f64, |acc, (&h, &b)| acc + h * b);
                    }
                    true
                }
                _ => false,
            },
        }
    }
}
/// Factorisation artifacts retained from an Arrow-Schur solve so that later
/// steps can be predicted or log-determinants evaluated without refactoring.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
    /// Per-row factors of the damped latent blocks (`htt + ridge_t · I`).
    pub htt_factors: Arc<[Array2<f64>]>,
    /// Per-row factors of the undamped latent blocks.
    pub htt_factors_undamped: ArrowUndampedFactors,
    /// Factor of the reduced (Schur) system, when one was produced.
    pub schur_factor: Option<Array2<f64>>,
    /// Solver mode the cache was built with.
    pub solver_mode: ArrowSolverMode,
    /// Latent ridge used for the damped factors.
    pub ridge_t: f64,
    /// Shared-block ridge used for the solve.
    pub ridge_beta: f64,
    /// Access to the per-row coupling blocks.
    pub htbeta: ArrowHtbetaCache,
    /// Latent dimension per row.
    pub d: usize,
    /// Shared coefficient dimension.
    pub k: usize,
    /// Manifold fingerprint of the system the cache came from.
    pub manifold_mode_fingerprint: u64,
    /// Combined Hessian fingerprint of the system the cache came from.
    pub row_hessian_fingerprint: u64,
}

impl ArrowFactorCache {
    /// Number of rows, taken from the damped factor list.
    pub fn n_rows(&self) -> usize {
        self.htt_factors.len()
    }

    /// Whether per-row `htbeta` products can be evaluated.
    pub fn htbeta_available(&self) -> bool {
        self.htbeta.is_available()
    }

    /// Undamped factor of `row`; falls back to the damped factor when the
    /// two are shared.
    ///
    /// # Panics
    /// Panics when `row` is out of range.
    pub fn undamped_factor(&self, row: usize) -> &Array2<f64> {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => &self.htt_factors[row],
            ArrowUndampedFactors::Owned(factors) => &factors[row],
        }
    }

    /// Number of undamped factors available.
    pub fn undamped_factor_count(&self) -> usize {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
            ArrowUndampedFactors::Owned(factors) => factors.len(),
        }
    }

    /// Iterates over the undamped factors in row order.
    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = &Array2<f64>> {
        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
    }

    /// Writes `htbeta[row] · delta_beta` into `out`.
    ///
    /// Returns `false` when `out` does not have length `d`, `delta_beta` does
    /// not have length `k`, or the underlying cache cannot serve the row.
    pub fn apply_htbeta_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        if out.len() != self.d || delta_beta.len() != self.k {
            return false;
        }
        self.htbeta.apply_row(row, delta_beta, out)
    }

    /// Predicts the latent response to a shared-coefficient change: for each
    /// row, `−(undamped htt)⁻¹ · (htbeta · delta_beta)`.
    ///
    /// Returns an all-zero vector when the `htbeta` cache is unavailable or
    /// cannot serve some row.
    ///
    /// # Panics
    /// Panics when `delta_beta.len() != self.k`.
    pub fn predict_delta_t_from_delta_beta(&self, delta_beta: ArrayView1<'_, f64>) -> Array1<f64> {
        let n = self.undamped_factor_count();
        let d = self.d;
        assert_eq!(delta_beta.len(), self.k);
        if !self.htbeta_available() {
            return Array1::<f64>::zeros(n * d);
        }
        let mut out = Array1::<f64>::zeros(n * d);
        let mut rhs = Array1::<f64>::zeros(d);
        for i in 0..n {
            if !self.apply_htbeta_row(i, delta_beta.view(), &mut rhs) {
                return Array1::<f64>::zeros(n * d);
            }
            let v = chol_solve_vector(self.undamped_factor(i), &rhs);
            for c in 0..d {
                out[i * d + c] = -v[c];
            }
        }
        out
    }

    /// Predicts the latent response to simultaneous changes in the shared
    /// coefficients and the latent gradient: for each row,
    /// `−(undamped htt)⁻¹ · (htbeta · delta_beta + delta_gt_i)`, with absent
    /// inputs treated as zero.
    ///
    /// Returns an all-zero vector when `delta_beta` is supplied but the
    /// `htbeta` cache cannot serve some row.
    ///
    /// # Panics
    /// Panics when a supplied `delta_beta` does not have length `k` or a
    /// supplied `delta_gt` does not have length `n * d`.
    pub fn predict_delta_t_combined(
        &self,
        delta_beta: Option<ArrayView1<'_, f64>>,
        delta_gt: Option<ArrayView1<'_, f64>>,
    ) -> Array1<f64> {
        let n = self.undamped_factor_count();
        let d = self.d;
        if let Some(db) = delta_beta.as_ref() {
            assert_eq!(db.len(), self.k);
        }
        if let Some(dg) = delta_gt.as_ref() {
            assert_eq!(dg.len(), n * d);
        }
        let mut out = Array1::<f64>::zeros(n * d);
        let mut rhs = Array1::<f64>::zeros(d);
        let mut htbeta_delta = Array1::<f64>::zeros(d);
        for i in 0..n {
            for c in 0..d {
                rhs[c] = 0.0;
            }
            if let Some(db) = delta_beta.as_ref() {
                htbeta_delta.fill(0.0);
                if !self.apply_htbeta_row(i, db.view(), &mut htbeta_delta) {
                    return Array1::<f64>::zeros(n * d);
                }
                for c in 0..d {
                    rhs[c] += htbeta_delta[c];
                }
            }
            if let Some(dg) = delta_gt.as_ref() {
                for c in 0..d {
                    rhs[c] += dg[i * d + c];
                }
            }
            let v = chol_solve_vector(self.undamped_factor(i), &rhs);
            for c in 0..d {
                out[i * d + c] = -v[c];
            }
        }
        out
    }

    /// Log-determinants read off the cached factors.
    ///
    /// The first component is `2 · Σ ln(diag)` over all damped row factors;
    /// the second is the same quantity for the Schur factor, or `None` when
    /// no Schur factor is cached.
    pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
        let mut log_det_tt = 0.0_f64;
        for l in self.htt_factors.iter() {
            for i in 0..l.nrows() {
                log_det_tt += l[[i, i]].ln();
            }
        }
        log_det_tt *= 2.0;
        let log_det_schur = self.schur_factor.as_ref().map(|l| {
            let mut s = 0.0_f64;
            for i in 0..l.nrows() {
                s += l[[i, i]].ln();
            }
            2.0 * s
        });
        (log_det_tt, log_det_schur)
    }

    /// Predicts the latent response to a latent-gradient change: for each
    /// row, `−(undamped htt)⁻¹ · delta_gt_i`.
    ///
    /// # Panics
    /// Panics when `delta_gt.len() != n * d`, or when the number of undamped
    /// factors differs from the number of rows in the cache.
    pub fn predict_delta_t_from_delta_gt(&self, delta_gt: ArrayView1<'_, f64>) -> Array1<f64> {
        let n = self.undamped_factor_count();
        let d = self.d;
        assert_eq!(delta_gt.len(), n * d);
        // Compare against the damped factor list; comparing the undamped
        // count with itself could never fail.
        assert_eq!(
            n,
            self.n_rows(),
            "undamped factor cache and N must agree"
        );
        let mut out = Array1::<f64>::zeros(n * d);
        let mut rhs = Array1::<f64>::zeros(d);
        for i in 0..n {
            for c in 0..d {
                rhs[c] = delta_gt[i * d + c];
            }
            let v = chol_solve_vector(self.undamped_factor(i), &rhs);
            for c in 0..d {
                out[i * d + c] = -v[c];
            }
        }
        out
    }
}
pub fn solve_arrow_newton_step_with_options(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, ArrowFactorCache), ArrowSchurError> {
if options.streaming_chunk_size.is_some() {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming Arrow-Schur solve does not materialize the factor cache required by this entry point".to_string(),
});
}
let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
let backend = CpuBatchedBlockSolver;
let htbeta_estimated_bytes =
estimated_htbeta_bytes(sys.rows.len(), sys.d, sys.k).unwrap_or(usize::MAX);
let htbeta = if let Some(op) = sys.htbeta_matvec.as_ref() {
ArrowHtbetaCache::Matvec {
op: Arc::clone(op),
estimated_bytes: htbeta_estimated_bytes,
}
} else if htbeta_estimated_bytes <= ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES {
ArrowHtbetaCache::Dense {
blocks: sys
.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>()
.into(),
estimated_bytes: htbeta_estimated_bytes,
}
} else {
ArrowHtbetaCache::Disabled {
estimated_bytes: htbeta_estimated_bytes,
}
};
let htt_factors = Arc::<[Array2<f64>]>::from(step.htt_factors);
let htt_factors_undamped = if ridge_t == 0.0 {
ArrowUndampedFactors::SameAsDamped
} else {
ArrowUndampedFactors::Owned(backend.factor_blocks(&sys.rows, 0.0, sys.d)?.into())
};
let cache = ArrowFactorCache {
htt_factors,
htt_factors_undamped,
schur_factor: step.schur_factor,
solver_mode: options.mode,
ridge_t,
ridge_beta,
htbeta,
d: sys.d,
k: sys.k,
manifold_mode_fingerprint: sys.manifold_mode_fingerprint,
row_hessian_fingerprint: sys.current_row_hessian_fingerprint(),
};
Ok((step.delta_t, step.delta_beta, cache))
}
/// Overflow-checked size in bytes of `n * d * k` values of type `f64`.
///
/// Returns `None` as soon as any intermediate product overflows `usize`.
fn estimated_htbeta_bytes(n: usize, d: usize, k: usize) -> Option<usize> {
    let entry_count = n.checked_mul(d).and_then(|nd| nd.checked_mul(k));
    entry_count.and_then(|count| count.checked_mul(std::mem::size_of::<f64>()))
}
/// Solves one arrow Newton step and returns only `(delta_t, delta_beta)`.
///
/// When `options.streaming_chunk_size` is set the solve is delegated to a
/// `StreamingArrowSchur` built from `sys`; otherwise the in-memory artifact
/// solve is used and its factors are dropped.
///
/// # Errors
/// Forwards any error from the selected solve path.
pub fn solve_arrow_newton_step_core(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
    match options.streaming_chunk_size {
        Some(chunk_size) => {
            let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
            let (delta_t, delta_beta, _) = streaming.solve(ridge_t, ridge_beta, options)?;
            Ok((delta_t, delta_beta))
        }
        None => {
            let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
            Ok((step.delta_t, step.delta_beta))
        }
    }
}
/// Solves the arrow Newton step, adding a growing ridge after factor failures.
///
/// The first solve uses the caller's ridges unchanged. After a
/// `PerRowFactorFailed`, `PerRowFactorIllConditioned` or `SchurFactorFailed`
/// error the same extra ridge is added to both `ridge_t` and `ridge_beta`:
/// `DEFAULT_PROXIMAL_INITIAL_RIDGE` first, then multiplied by
/// `DEFAULT_PROXIMAL_RIDGE_GROWTH` after each further failure. At most
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS + 1` solves are made.
///
/// # Errors
/// Returns the most recent solve error when it is not one of the three
/// recoverable kinds or when the attempt budget is exhausted.
pub fn solve_with_lm_escalation_inner(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
    let mut extra_ridge = 0.0_f64;
    let mut attempt = 0_usize;
    loop {
        let failure = match solve_arrow_newton_step_core(
            sys,
            ridge_t + extra_ridge,
            ridge_beta + extra_ridge,
            options,
        ) {
            Ok(step) => return Ok(step),
            Err(err) => err,
        };
        let recoverable = matches!(
            failure,
            ArrowSchurError::PerRowFactorFailed { .. }
                | ArrowSchurError::PerRowFactorIllConditioned { .. }
                | ArrowSchurError::SchurFactorFailed { .. }
        );
        if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
            return Err(failure);
        }
        extra_ridge = if extra_ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            extra_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
        attempt += 1;
    }
}
/// Searches for a damped arrow Newton step that satisfies an Armijo test.
///
/// Each attempt adds the current proximal ridge to both base ridges, solves
/// the step, and accepts it when the gradient-step inner product is finite
/// and negative and `trial_objective` returns a finite value no larger than
/// `current_objective_value + armijo_c1 * g·p`. Otherwise the ridge is
/// advanced with `next_proximal_ridge` and the next attempt runs.
/// `trial_objective` is only invoked for candidates that pass the descent
/// check.
///
/// When the gradient norm is at most `gradient_tolerance.max(0.0)` a zero
/// step is returned immediately with `attempts == 0`.
///
/// # Errors
/// Returns `AdaptiveCorrectionFailed` when `current_objective_value` is not
/// finite, when `ridge_growth` is not finite and greater than one, when
/// `armijo_c1` is not inside `(0, 1)`, or when no attempt is accepted; the
/// last rejection reason is embedded in the message.
pub fn solve_arrow_newton_step_with_proximal_correction<F>(
    sys: &ArrowSchurSystem,
    base_ridge_t: f64,
    base_ridge_beta: f64,
    current_objective_value: f64,
    options: &ArrowSolveOptions,
    correction: &ArrowProximalCorrectionOptions,
    mut trial_objective: F,
) -> Result<ArrowAcceptedProximalStep, ArrowSchurError>
where
    F: for<'a, 'b> FnMut(ArrayView1<'a, f64>, ArrayView1<'b, f64>) -> f64,
{
    if !current_objective_value.is_finite() {
        let reason = String::from("current objective is not finite");
        return Err(ArrowSchurError::AdaptiveCorrectionFailed { reason });
    }
    let growth = correction.ridge_growth;
    if !growth.is_finite() || growth <= 1.0 {
        return Err(ArrowSchurError::AdaptiveCorrectionFailed {
            reason: format!(
                "ridge_growth must be finite and > 1; got {}",
                correction.ridge_growth
            ),
        });
    }
    let c1 = correction.armijo_c1;
    if !c1.is_finite() || c1 <= 0.0 || c1 >= 1.0 {
        return Err(ArrowSchurError::AdaptiveCorrectionFailed {
            reason: format!("armijo_c1 must be in (0, 1); got {}", correction.armijo_c1),
        });
    }

    // Already stationary within tolerance: report a zero step without solving.
    if arrow_gradient_norm(sys) <= correction.gradient_tolerance.max(0.0) {
        return Ok(ArrowAcceptedProximalStep {
            delta_t: Array1::<f64>::zeros(sys.rows.len() * sys.d),
            delta_beta: Array1::<f64>::zeros(sys.k),
            ridge_t: base_ridge_t,
            ridge_beta: base_ridge_beta,
            proximal_ridge: 0.0,
            objective_value: current_objective_value,
            trial_objective_value: current_objective_value,
            gradient_dot_step: 0.0,
            attempts: 0,
        });
    }

    let mut proximal_ridge = correction.initial_ridge.max(0.0);
    let mut last_reason = String::from("no attempts were made");
    for attempt in 0..correction.max_attempts {
        let ridge_t = base_ridge_t + proximal_ridge;
        let ridge_beta = base_ridge_beta + proximal_ridge;
        last_reason = match solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options) {
            Err(err) => err.to_string(),
            Ok((delta_t, delta_beta)) => {
                let g_dot_p = arrow_gradient_dot_step(sys, delta_t.view(), delta_beta.view());
                if g_dot_p.is_finite() && g_dot_p < 0.0 {
                    let trial_value = trial_objective(delta_t.view(), delta_beta.view());
                    let armijo_bound = current_objective_value + correction.armijo_c1 * g_dot_p;
                    if trial_value.is_finite() && trial_value <= armijo_bound {
                        return Ok(ArrowAcceptedProximalStep {
                            delta_t,
                            delta_beta,
                            ridge_t,
                            ridge_beta,
                            proximal_ridge,
                            objective_value: current_objective_value,
                            trial_objective_value: trial_value,
                            gradient_dot_step: g_dot_p,
                            attempts: attempt + 1,
                        });
                    }
                    format!("Armijo rejected trial objective {trial_value}; bound {armijo_bound}")
                } else {
                    format!("candidate was not a finite descent direction: g·p={g_dot_p}")
                }
            }
        };
        proximal_ridge = next_proximal_ridge(proximal_ridge, correction.ridge_growth);
    }
    Err(ArrowSchurError::AdaptiveCorrectionFailed {
        reason: format!(
            "failed after {} attempts; last rejection: {last_reason}",
            correction.max_attempts
        ),
    })
}
/// Reduction predicted by the ridged quadratic model for a given step.
///
/// Returns `-(g·p + ½ pᵀ H p)`, where `g` is assembled from `sys.gb` and
/// every row's `gt`, and `H` from the shared block (dense or matrix-free),
/// every row's `htt` and `htbeta` blocks, plus `ridge_beta` and `ridge_t`
/// on the respective diagonals.
///
/// # Errors
/// Returns `SchurFactorFailed` when there is no matrix-free shared operator
/// and `sys.hbb` is not `k × k`.
///
/// # Panics
/// Panics if `delta_t.len() != sys.rows.len() * sys.d` or
/// `delta_beta.len() != sys.k`.
pub fn arrow_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let d = sys.d;
    let k = sys.k;
    assert_eq!(delta_t.len(), sys.rows.len() * d);
    assert_eq!(delta_beta.len(), k);

    let mut linear = sys.gb.dot(&delta_beta);
    let mut quadratic = ridge_beta * delta_beta.dot(&delta_beta);

    // Shared-block contribution, via the operator when one is provided.
    let mut shared_product = Array1::<f64>::zeros(k);
    match sys.hbb_matvec.as_ref() {
        Some(apply_hbb) => apply_hbb(delta_beta, &mut shared_product),
        None if sys.hbb.dim() == (k, k) => shared_product.assign(&sys.hbb.dot(&delta_beta)),
        None => {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "Arrow-Schur predicted reduction requires a dense H_ββ block or matrix-free H_ββ operator".to_string(),
            });
        }
    }
    quadratic += delta_beta.dot(&shared_product);

    // Per-row contributions: gradient, ridge, row block and cross block.
    for (row_index, row) in sys.rows.iter().enumerate() {
        let offset = row_index * d;
        for c in 0..d {
            let step_c = delta_t[offset + c];
            linear += row.gt[c] * step_c;
            quadratic += ridge_t * step_c * step_c;
            for r in 0..d {
                quadratic += step_c * row.htt[[c, r]] * delta_t[offset + r];
            }
            for b in 0..k {
                quadratic += 2.0 * step_c * row.htbeta[[c, b]] * delta_beta[b];
            }
        }
    }
    Ok(-(linear + 0.5 * quadratic))
}
/// Next ridge in the escalation sequence.
///
/// A strictly positive `current` is multiplied by `growth`; anything else
/// (zero, negative, NaN) restarts at `DEFAULT_PROXIMAL_INITIAL_RIDGE`.
fn next_proximal_ridge(current: f64, growth: f64) -> f64 {
    let is_positive = current > 0.0;
    if !is_positive {
        return DEFAULT_PROXIMAL_INITIAL_RIDGE;
    }
    current * growth
}
fn arrow_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
let mut sum = 0.0;
for row in sys.rows.iter() {
for &v in row.gt.iter() {
sum += v * v;
}
}
for &v in sys.gb.iter() {
sum += v * v;
}
sum.sqrt()
}
/// Inner product of the full gradient with the step `(delta_t, delta_beta)`.
///
/// Row contributions are accumulated first, in row order, followed by the
/// shared-coefficient contributions.
///
/// # Panics
/// Panics if `delta_t.len() != sys.rows.len() * sys.d` or
/// `delta_beta.len() != sys.k`.
fn arrow_gradient_dot_step(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    let d = sys.d;
    assert_eq!(delta_t.len(), sys.rows.len() * d);
    assert_eq!(delta_beta.len(), sys.k);
    let mut total = 0.0;
    for (row_index, row) in sys.rows.iter().enumerate() {
        for c in 0..d {
            total += row.gt[c] * delta_t[row_index * d + c];
        }
    }
    for idx in 0..sys.k {
        total += sys.gb[idx] * delta_beta[idx];
    }
    total
}
/// Result bundle of one arrow Newton step solve: the step itself plus the
/// factors produced while computing it.
struct ArrowNewtonStepArtifacts {
// Per-row step, stored row after row in one flat vector.
delta_t: Array1<f64>,
// Step for the shared coefficients.
delta_beta: Array1<f64>,
// Per-row factors from the block backend; left empty by the streaming path.
htt_factors: Vec<Array2<f64>>,
// Factor of the reduced system when a dense solve produced one; `None` in
// `InexactPCG` mode.
schur_factor: Option<Array2<f64>>,
}
fn solve_arrow_newton_step_artifacts(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
if let Some(chunk_size) = options.streaming_chunk_size {
let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
let (delta_t, delta_beta, schur_factor) = streaming.solve(ridge_t, ridge_beta, options)?;
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors: Vec::new(),
schur_factor,
});
}
let n = sys.rows.len();
let d = sys.d;
let k = sys.k;
let backend = CpuBatchedBlockSolver;
let htt_factors = backend.factor_blocks(&sys.rows, ridge_t, d)?;
let rhs_beta = reduced_rhs_beta(sys, &htt_factors, &backend);
let trust_metric_weights = None;
let (delta_beta, schur_factor) = match options.mode {
ArrowSolverMode::Direct => {
let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, &backend)?;
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?
}
ArrowSolverMode::SqrtBA => {
let schur = build_dense_schur_sqrt_ba(sys, &htt_factors, ridge_beta, &backend)?;
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?
}
ArrowSolverMode::InexactPCG => {
let preconditioner =
JacobiPreconditioner::from_arrow_schur(sys, &htt_factors, ridge_beta, &backend)?;
let delta = steihaug_pcg_reduced_system(
sys,
&htt_factors,
ridge_beta,
&rhs_beta,
&preconditioner,
&options.pcg,
&options.trust_region,
&backend,
trust_metric_weights,
)?;
(delta, None)
}
};
let mut delta_t = Array1::<f64>::zeros(n * d);
let mut rhs = Array1::<f64>::zeros(d);
for i in 0..n {
assert_eq!(sys.rows[i].gt.len(), d);
assert_eq!(sys.rows[i].htbeta.dim(), (d, k));
for c in 0..d {
let mut acc = sys.rows[i].gt[c];
for a in 0..k {
acc += sys.rows[i].htbeta[[c, a]] * delta_beta[a];
}
rhs[c] = acc;
}
let dt_i = backend.solve_block_vector(&htt_factors[i], &rhs);
for c in 0..d {
delta_t[i * d + c] = -dt_i[c];
}
}
Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor,
})
}
/// Right-hand side of the reduced shared system.
///
/// For each row the row gradient is solved against the row factor and the
/// result is pushed through the transpose of the row's cross block; the
/// shared gradient `sys.gb` is then subtracted from the accumulated sum.
///
/// # Panics
/// Panics if a row's `htbeta` is not `d × k`.
fn reduced_rhs_beta<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    backend: &B,
) -> Array1<f64> {
    let (d, k) = (sys.d, sys.k);
    let mut reduced = Array1::<f64>::zeros(k);
    for (row_index, row) in sys.rows.iter().enumerate() {
        assert_eq!(row.htbeta.dim(), (d, k));
        let solved = backend.solve_block_vector(&htt_factors[row_index], &row.gt);
        // Exact-zero components contribute nothing and are skipped.
        for c in (0..d).filter(|&c| solved[c] != 0.0) {
            let weight = solved[c];
            for a in 0..k {
                reduced[a] += row.htbeta[[c, a]] * weight;
            }
        }
    }
    for a in 0..k {
        reduced[a] -= sys.gb[a];
    }
    reduced
}
/// Assembles the dense reduced matrix for the direct solver mode.
///
/// Starts from a copy of `sys.hbb` with `ridge_beta` added to its diagonal.
/// For each row the cross block is solved against the row factor with
/// `solve_block_matrix`, and the backend's `block_gemm_subtract` is applied
/// with the original and solved cross blocks. The result is finally passed
/// through `symmetrize_upper_from_lower`.
///
/// # Errors
/// Returns `SchurFactorFailed` when `sys.hbb` is not `k × k`.
fn build_dense_schur_direct<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
    let dim = sys.k;
    if sys.hbb.dim() != (dim, dim) {
        let reason = String::from("Direct BA requires a dense K×K shared H_ββ block");
        return Err(ArrowSchurError::SchurFactorFailed { reason });
    }
    let mut reduced = sys.hbb.clone();
    for idx in 0..dim {
        reduced[[idx, idx]] += ridge_beta;
    }
    for (row_index, row) in sys.rows.iter().enumerate() {
        let factor = &htt_factors[row_index];
        let solved_cross = backend.solve_block_matrix(factor, &row.htbeta);
        backend.block_gemm_subtract(&mut reduced, &row.htbeta, &solved_cross);
    }
    symmetrize_upper_from_lower(&mut reduced);
    Ok(reduced)
}
/// Assembles the dense reduced matrix for the square-root solver mode.
///
/// Same outline as the direct assembly, except that each row's cross block
/// is first transformed by the backend's `sqrt_solve_block_matrix` and the
/// transformed block is passed as both operands of `block_gemm_subtract`.
///
/// # Errors
/// Returns `SchurFactorFailed` when `sys.hbb` is not `k × k`.
fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
    let dim = sys.k;
    if sys.hbb.dim() != (dim, dim) {
        let reason =
            String::from("Square-Root BA direct solve requires a dense K×K shared H_ββ block");
        return Err(ArrowSchurError::SchurFactorFailed { reason });
    }
    let mut reduced = sys.hbb.clone();
    for idx in 0..dim {
        reduced[[idx, idx]] += ridge_beta;
    }
    for (row_index, row) in sys.rows.iter().enumerate() {
        let factor = &htt_factors[row_index];
        let whitened_cross = backend.sqrt_solve_block_matrix(factor, &row.htbeta);
        backend.block_gemm_subtract(&mut reduced, &whitened_cross, &whitened_cross);
    }
    symmetrize_upper_from_lower(&mut reduced);
    Ok(reduced)
}
/// Solves the dense reduced system, honouring the trust region.
///
/// The matrix is Cholesky-factored and solved directly. If that step passes
/// `step_inside_trust_region` it is returned as-is; otherwise an
/// unpreconditioned Steihaug CG solve on the same matrix supplies the step,
/// using the trust-region options' iteration cap and relative tolerance.
/// The Cholesky factor is returned in both cases.
///
/// # Errors
/// Returns `SchurFactorFailed` when the factorization fails and forwards any
/// error from the Steihaug solve.
fn solve_dense_reduced_system(
    schur: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    options: &ArrowSolveOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
    let factor = match cholesky_lower(schur) {
        Ok(lower) => lower,
        Err(reason) => return Err(ArrowSchurError::SchurFactorFailed { reason }),
    };
    let newton_step = chol_solve_vector(&factor, rhs_beta);
    let trust = &options.trust_region;
    if !step_inside_trust_region(newton_step.view(), trust.radius, metric_weights) {
        let truncated_cg = ArrowPcgOptions {
            max_iterations: trust.max_iterations,
            relative_tolerance: trust.steihaug_relative_tolerance,
        };
        let bounded_step = steihaug_dense_system(
            schur,
            rhs_beta,
            &IdentityPreconditioner,
            &truncated_cg,
            trust,
            metric_weights,
        )?;
        return Ok((bounded_step, Some(factor)));
    }
    Ok((newton_step, Some(factor)))
}
/// Reports whether `step` is acceptable under the trust-region `radius`.
///
/// A radius that is non-finite or not strictly positive means "no trust
/// region", matching how `steihaug_cg` normalises its own radius; every step
/// is then inside. Otherwise the step is inside when its (optionally
/// weighted) norm does not exceed `radius`.
///
/// Previously a finite radius `<= 0` rejected every non-zero step, which sent
/// the caller to a Steihaug solve that itself ran unbounded, replacing the
/// exact direct step with an iteration-capped approximation.
fn step_inside_trust_region(
    step: ArrayView1<'_, f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> bool {
    let bounded = radius.is_finite() && radius > 0.0;
    !bounded || metric_norm(step, metric_weights) <= radius
}
/// Applies the reduced operator to `x` without forming it, writing into `out`.
///
/// `out` is first set to the shared block times `x` plus `ridge_beta * x`
/// (through the matrix-free operator when the system provides one, else the
/// dense `sys.hbb`). Then, for each row, the cross block is applied to `x`,
/// the result is solved against the row factor, and the transpose of the
/// cross block applied to that solution is subtracted from `out`.
///
/// # Panics
/// Panics if a row's `htbeta` is not `d × k`.
fn schur_matvec<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    x: &Array1<f64>,
    out: &mut Array1<f64>,
    backend: &B,
) {
    let (d, k) = (sys.d, sys.k);
    match sys.hbb_matvec.as_ref() {
        Some(apply_hbb) => {
            apply_hbb(x.view(), out);
            for a in 0..k {
                out[a] += ridge_beta * x[a];
            }
        }
        None => {
            for a in 0..k {
                let shared = (0..k).fold(ridge_beta * x[a], |acc, b| acc + sys.hbb[[a, b]] * x[b]);
                out[a] = shared;
            }
        }
    }

    let mut projected = Array1::<f64>::zeros(d);
    for (row_index, row) in sys.rows.iter().enumerate() {
        assert_eq!(row.htbeta.dim(), (d, k));
        for c in 0..d {
            projected[c] = (0..k).fold(0.0_f64, |acc, a| acc + row.htbeta[[c, a]] * x[a]);
        }
        let solved = backend.solve_block_vector(&htt_factors[row_index], &projected);
        // Exact-zero components contribute nothing and are skipped.
        for c in (0..d).filter(|&c| solved[c] != 0.0) {
            let weight = solved[c];
            for a in 0..k {
                out[a] -= row.htbeta[[c, a]] * weight;
            }
        }
    }
}
/// Diagonal preconditioner for the reduced shared system.
#[derive(Debug, Clone)]
pub struct JacobiPreconditioner {
// Reciprocals of the reduced-operator diagonal; `apply` multiplies by these
// entry by entry.
inverse_diag: Array1<f64>,
}
impl JacobiPreconditioner {
    /// Builds the preconditioner from the diagonal of the reduced operator.
    ///
    /// Each diagonal entry starts as the shared block's diagonal (taken from
    /// `sys.hbb_diag` when present, else from `sys.hbb`) plus `ridge_beta`.
    /// For every row and every column of its cross block, the column is
    /// solved against the row factor and the inner product of the column
    /// with that solution is subtracted.
    ///
    /// # Errors
    /// Returns `PcgFailed` when any resulting diagonal entry is non-finite or
    /// not larger than `1e-18`.
    pub fn from_arrow_schur<B: BatchedBlockSolver>(
        sys: &ArrowSchurSystem,
        htt_factors: &[Array2<f64>],
        ridge_beta: f64,
        backend: &B,
    ) -> Result<Self, ArrowSchurError> {
        let (d, k) = (sys.d, sys.k);
        let mut diag = Array1::<f64>::zeros(k);
        for a in 0..k {
            let shared_entry = if let Some(hbb_diag) = sys.hbb_diag.as_ref() {
                hbb_diag[a]
            } else {
                sys.hbb[[a, a]]
            };
            diag[a] = shared_entry + ridge_beta;
        }

        let mut column = Array1::<f64>::zeros(d);
        for (row_index, row) in sys.rows.iter().enumerate() {
            let factor = &htt_factors[row_index];
            for a in 0..k {
                for c in 0..d {
                    column[c] = row.htbeta[[c, a]];
                }
                let solved = backend.solve_block_vector(factor, &column);
                let correction = (0..d).fold(0.0_f64, |acc, c| acc + column[c] * solved[c]);
                diag[a] -= correction;
            }
        }

        let mut inverse_diag = Array1::<f64>::zeros(k);
        for a in 0..k {
            let v = diag[a];
            let usable = v.is_finite() && v > 1e-18;
            if !usable {
                return Err(ArrowSchurError::PcgFailed {
                    reason: format!(
                        "invalid Schur Jacobi diagonal at index {a}: {v}; \
                         operator regularization is required"
                    ),
                });
            }
            inverse_diag[a] = 1.0 / v;
        }
        Ok(Self { inverse_diag })
    }

    /// Scales every entry of `r` by the matching inverse diagonal entry.
    fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
        let mut scaled = Array1::<f64>::zeros(r.len());
        for (i, slot) in scaled.iter_mut().enumerate() {
            *slot = self.inverse_diag[i] * r[i];
        }
        scaled
    }
}
/// Preconditioner that leaves the residual unchanged.
#[derive(Debug, Clone, Copy)]
struct IdentityPreconditioner;

impl IdentityPreconditioner {
    /// Returns an owned copy of `r`.
    fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
        r.to_owned()
    }
}
/// Runs Steihaug CG on the matrix-free reduced system with Jacobi
/// preconditioning.
///
/// The iteration cap is the smaller of the PCG and trust-region caps, and the
/// relative tolerance is the larger of the two configured tolerances.
///
/// # Errors
/// Forwards any error from `steihaug_cg`.
fn steihaug_pcg_reduced_system<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &[Array2<f64>],
    ridge_beta: f64,
    rhs: &Array1<f64>,
    preconditioner: &JacobiPreconditioner,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    backend: &B,
    metric_weights: Option<&MetricWeights>,
) -> Result<Array1<f64>, ArrowSchurError> {
    let iteration_cap = pcg.max_iterations.min(trust.max_iterations);
    let tolerance = pcg
        .relative_tolerance
        .max(trust.steihaug_relative_tolerance);
    let apply_operator = |direction: &Array1<f64>, product: &mut Array1<f64>| {
        schur_matvec(sys, htt_factors, ridge_beta, direction, product, backend)
    };
    let apply_jacobi = |residual: &Array1<f64>| preconditioner.apply(residual);
    steihaug_cg(
        rhs,
        apply_operator,
        apply_jacobi,
        iteration_cap,
        tolerance,
        trust.radius,
        metric_weights,
    )
}
/// Runs Steihaug CG on an explicit dense matrix with the identity
/// preconditioner, using the iteration cap and tolerance from `pcg` and the
/// radius from `trust`.
///
/// # Errors
/// Forwards any error from `steihaug_cg`.
fn steihaug_dense_system(
    schur: &Array2<f64>,
    rhs: &Array1<f64>,
    preconditioner: &IdentityPreconditioner,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<Array1<f64>, ArrowSchurError> {
    let apply_operator =
        |direction: &Array1<f64>, product: &mut Array1<f64>| dense_matvec(schur, direction, product);
    let apply_identity = |residual: &Array1<f64>| preconditioner.apply(residual);
    steihaug_cg(
        rhs,
        apply_operator,
        apply_identity,
        pcg.max_iterations,
        pcg.relative_tolerance,
        trust.radius,
        metric_weights,
    )
}
/// Preconditioned conjugate gradients with Steihaug-style trust-region exits.
///
/// Starts from the zero vector and iterates on the system defined by
/// `matvec` and `rhs`. `matvec(p, out)` writes the operator applied to `p`
/// into `out`; `apply_preconditioner(r)` returns the preconditioned residual.
/// All inner products and norms go through `metric_dot` / `metric_norm` with
/// the optional `metric_weights`.
///
/// Returns the current iterate when the residual norm drops to
/// `relative_tolerance.max(0.0) * ||rhs||` or when `max_iterations` is used
/// up, and a point on the trust-region boundary when a finite radius is
/// reached or non-positive curvature is met.
///
/// # Errors
/// Returns `PcgFailed` when the preconditioned residual product or the
/// curvature is non-positive or non-finite and no finite radius is active,
/// or when a later preconditioned residual product is non-positive or
/// non-finite (regardless of the radius).
///
/// # Panics
/// Panics if `metric_weights` is supplied with a length different from
/// `rhs.len()`.
fn steihaug_cg<MatVec, ApplyPrec>(
rhs: &Array1<f64>,
mut matvec: MatVec,
mut apply_preconditioner: ApplyPrec,
max_iterations: usize,
relative_tolerance: f64,
trust_radius: f64,
metric_weights: Option<&MetricWeights>,
) -> Result<Array1<f64>, ArrowSchurError>
where
MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
{
let n = rhs.len();
if let Some(weights) = metric_weights {
assert_eq!(
weights.len(),
n,
"Steihaug-CG metric weight length must match solve dimension"
);
}
// A radius that is NaN, infinite, zero or negative disables the trust region.
let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
trust_radius
} else {
f64::INFINITY
};
let rhs_norm = metric_norm(rhs.view(), metric_weights);
// Zero right-hand side: the zero vector is returned without iterating.
if rhs_norm == 0.0 {
return Ok(Array1::<f64>::zeros(n));
}
let tol = relative_tolerance.max(0.0) * rhs_norm;
// The iterate starts at zero, so the initial residual equals `rhs`.
let mut x = Array1::<f64>::zeros(n);
let mut r = rhs.clone();
let mut z = apply_preconditioner(&r);
let mut p = z.clone();
let mut rz = metric_dot(&r, &z, metric_weights);
// Unusable preconditioned residual: with a finite radius, move from `x`
// along the raw residual to the boundary; otherwise give up.
if rz <= 0.0 || !rz.is_finite() {
if radius.is_finite() {
return Ok(step_to_trust_boundary(&x, &r, radius, metric_weights));
}
return Err(ArrowSchurError::PcgFailed {
reason: "non-positive preconditioned residual in Schur PCG".to_string(),
});
}
if metric_norm(r.view(), metric_weights) <= tol {
return Ok(x);
}
let mut ap = Array1::<f64>::zeros(n);
let mut candidate = Array1::<f64>::zeros(n);
for _ in 0..max_iterations {
matvec(&p, &mut ap);
let pap = metric_dot(&p, &ap, metric_weights);
// Non-positive or non-finite curvature along `p`: follow `p` to the
// boundary when one exists, else report failure.
if pap <= 0.0 || !pap.is_finite() {
if radius.is_finite() {
return Ok(step_to_trust_boundary(&x, &p, radius, metric_weights));
}
return Err(ArrowSchurError::PcgFailed {
reason: "negative curvature in unbounded Schur PCG".to_string(),
});
}
let alpha = rz / pap;
for i in 0..n {
candidate[i] = x[i] + alpha * p[i];
}
// The full CG step would reach or leave the trust region: stop on the
// boundary along `p` instead of taking it.
if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
return Ok(step_to_trust_boundary(&x, &p, radius, metric_weights));
}
x.assign(&candidate);
for i in 0..n {
r[i] -= alpha * ap[i];
}
if metric_norm(r.view(), metric_weights) <= tol {
return Ok(x);
}
z = apply_preconditioner(&r);
let rz_next = metric_dot(&r, &z, metric_weights);
// NOTE(review): unlike the two checks above, this failure is returned even
// when a finite radius is active — confirm that is intended.
if rz_next <= 0.0 || !rz_next.is_finite() {
return Err(ArrowSchurError::PcgFailed {
reason: "non-positive or non-finite PCG residual".to_string(),
});
}
let beta = rz_next / rz;
for i in 0..n {
p[i] = z[i] + beta * p[i];
}
rz = rz_next;
}
// Iteration budget exhausted: the current iterate is returned as-is.
Ok(x)
}
/// Moves from `x` along `p` to the trust-region boundary of the given radius.
///
/// Solves `||x + tau * p|| = radius` in the (optionally weighted) metric for
/// the root `tau = (-x·p + sqrt(disc)) / p·p`, with the discriminant clamped
/// at zero. Returns a copy of `x` when `p·p` is exactly zero.
fn step_to_trust_boundary(
    x: &Array1<f64>,
    p: &Array1<f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> Array1<f64> {
    let pp = metric_dot(p, p, metric_weights);
    if pp == 0.0 {
        return x.clone();
    }
    let xp = metric_dot(x, p, metric_weights);
    let xx = metric_dot(x, x, metric_weights);
    let discriminant = (xp * xp + pp * (radius * radius - xx)).max(0.0);
    let tau = (-xp + discriminant.sqrt()) / pp;
    let mut boundary_point = x.clone();
    for (i, slot) in boundary_point.iter_mut().enumerate() {
        *slot += tau * p[i];
    }
    boundary_point
}
/// Writes `a · x` into `out`, using the leading `nrows × nrows` part of `a`.
fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
    let size = a.nrows();
    for row in 0..size {
        let row_sum = (0..size).fold(0.0_f64, |acc, col| acc + a[[row, col]] * x[col]);
        out[row] = row_sum;
    }
}
/// Sequential inner product over the first `a.len()` entries of `a` and `b`.
fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    (0..a.len()).fold(0.0_f64, |acc, i| acc + a[i] * b[i])
}
/// Inner product of `a` and `b`, weighted entry by entry when
/// `metric_weights` is supplied and plain otherwise.
///
/// # Panics
/// Panics if `a` and `b` differ in length, or if the weights are supplied
/// with a length different from `a.len()`.
fn metric_dot(a: &Array1<f64>, b: &Array1<f64>, metric_weights: Option<&MetricWeights>) -> f64 {
    assert_eq!(a.len(), b.len());
    let Some(weights) = metric_weights else {
        return dot(a, b);
    };
    assert_eq!(weights.len(), a.len());
    (0..a.len()).fold(0.0_f64, |acc, i| acc + weights[i] * a[i] * b[i])
}
fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
let mut acc = 0.0;
match metric_weights {
Some(weights) => {
assert_eq!(weights.len(), v.len());
for i in 0..v.len() {
acc += weights[i] * v[i] * v[i];
}
}
None => {
for x in v.iter() {
acc += x * x;
}
}
}
acc.sqrt()
}
/// Makes the leading square part of `a` symmetric by replacing each
/// off-diagonal pair with the average of its two entries.
///
/// NOTE(review): despite the name, the lower triangle is not copied over the
/// upper one; both are averaged. If a producer only updates one triangle
/// this would blend in stale values — confirm against the callers' backend.
fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
    let size = a.nrows().min(a.ncols());
    for col in 0..size {
        for row in (col + 1)..size {
            let mean = 0.5 * (a[[row, col]] + a[[col, row]]);
            a[[row, col]] = mean;
            a[[col, row]] = mean;
        }
    }
}
/// Failure modes reported by the arrow-Schur solve paths.
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// Factoring the block of row `row` failed; `reason` carries the detail.
    PerRowFactorFailed { row: usize, reason: String },
    /// The block of row `row` factored but its condition estimate was
    /// rejected.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// The reduced (Schur) system could not be built or factored.
    SchurFactorFailed { reason: String },
    /// The iterative reduced solve could not proceed.
    PcgFailed { reason: String },
    /// The adaptive proximal correction rejected its inputs or found no
    /// acceptable step.
    AdaptiveCorrectionFailed { reason: String },
}

impl std::fmt::Display for ArrowSchurError {
    /// Writes a one-line, `arrow-Schur:`-prefixed description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AdaptiveCorrectionFailed { reason } => f.write_fmt(format_args!(
                "arrow-Schur: adaptive proximal correction failed: {reason}"
            )),
            Self::PcgFailed { reason } => {
                f.write_fmt(format_args!("arrow-Schur: Schur PCG failed: {reason}"))
            }
            Self::SchurFactorFailed { reason } => f.write_fmt(format_args!(
                "arrow-Schur: Schur complement Cholesky failed: {reason}"
            )),
            Self::PerRowFactorIllConditioned {
                row,
                kappa_estimate,
            } => f.write_fmt(format_args!(
                "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded \
                 but is ill-conditioned (kappa_estimate={kappa_estimate:e}); \
                 Schur reduction would be numerically contaminated"
            )),
            Self::PerRowFactorFailed { row, reason } => f.write_fmt(format_args!(
                "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}"
            )),
        }
    }
}

impl std::error::Error for ArrowSchurError {}
/// Lower-triangular Cholesky factor `L` of `a`, with `L · Lᵀ = a`.
///
/// Only the lower triangle of `a` (including the diagonal) is read by the
/// factorization itself; the finiteness pre-check scans every entry.
///
/// # Errors
/// Returns a message when `a` is not square, contains a non-finite entry, or
/// produces a pivot that is non-finite or not strictly positive.
fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(format!("cholesky_lower: non-square {}×{}", n, a.ncols()));
    }
    if let Some(idx) = a.iter().position(|entry| !entry.is_finite()) {
        return Err(format!(
            "cholesky_lower: non-finite entry at linear index {idx}"
        ));
    }
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        // Off-diagonal entries of row `i`.
        for j in 0..i {
            let residual = (0..j).fold(a[[i, j]], |acc, kk| acc - l[[i, kk]] * l[[j, kk]]);
            l[[i, j]] = residual / l[[j, j]];
        }
        // Diagonal pivot of row `i`.
        let sum = (0..i).fold(a[[i, i]], |acc, kk| acc - l[[i, kk]] * l[[i, kk]]);
        let pivot_ok = sum.is_finite() && sum > 0.0;
        if !pivot_ok {
            return Err(format!(
                "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
            ));
        }
        l[[i, i]] = sum.sqrt();
    }
    Ok(l)
}
/// Solves `L · Lᵀ · x = b` for `x`, given the lower-triangular factor `l`.
///
/// Performs a forward substitution with `L` followed by a backward
/// substitution with `Lᵀ`. No check is made on the diagonal of `l`.
fn chol_solve_vector(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let size = l.nrows();

    // Forward pass: L · y = b.
    let mut forward = Array1::<f64>::zeros(size);
    for row in 0..size {
        let partial = (0..row).fold(b[row], |acc, kk| acc - l[[row, kk]] * forward[kk]);
        forward[row] = partial / l[[row, row]];
    }

    // Backward pass: Lᵀ · x = y.
    let mut solution = Array1::<f64>::zeros(size);
    for row in (0..size).rev() {
        let partial =
            ((row + 1)..size).fold(forward[row], |acc, kk| acc - l[[kk, row]] * solution[kk]);
        solution[row] = partial / l[[row, row]];
    }
    solution
}
/// Solves `L · Lᵀ · X = B` column by column with `chol_solve_vector`.
///
/// The result has `l.nrows()` rows and `b.ncols()` columns.
fn chol_solve_matrix(l: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = l.nrows();
    let cols = b.ncols();
    let mut solution = Array2::<f64>::zeros((rows, cols));
    let mut column = Array1::<f64>::zeros(rows);
    for col_index in 0..cols {
        for (row_index, slot) in column.iter_mut().enumerate() {
            *slot = b[[row_index, col_index]];
        }
        let solved = chol_solve_vector(l, &column);
        for row_index in 0..rows {
            solution[[row_index, col_index]] = solved[row_index];
        }
    }
    solution
}
/// Solves `L · X = B` by forward substitution, one column of `B` at a time.
///
/// The result has `l.nrows()` rows and `b.ncols()` columns. No check is made
/// on the diagonal of `l`.
fn lower_triangular_solve_matrix(l: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = l.nrows();
    let cols = b.ncols();
    let mut solution = Array2::<f64>::zeros((rows, cols));
    for col in 0..cols {
        for row in 0..rows {
            let partial =
                (0..row).fold(b[[row, col]], |acc, kk| acc - l[[row, kk]] * solution[[kk, col]]);
            solution[[row, col]] = partial / l[[row, row]];
        }
    }
    solution
}
// Unit tests for the arrow-Schur solve paths, the proximal correction, the
// per-row conditioning check and the ridge escalation.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Solves a small 2-row system through the arrow path (in-memory and
// streaming) and compares both against a dense Cholesky solve of the
// assembled joint system.
#[test]
fn arrow_schur_matches_dense_reference_2x2() {
let n = 2;
let d = 2;
let k = 3;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0, 0.5], [0.2, 1.0, 0.0]];
sys.rows[0].gt = array![0.3_f64, -0.2];
sys.rows[1].htt = array![[1.5_f64, -0.1], [-0.1, 2.0]];
sys.rows[1].htbeta = array![[0.1_f64, 0.5, 0.0], [0.0, 0.3, 1.0]];
sys.rows[1].gt = array![-0.1_f64, 0.4];
sys.hbb = array![[4.0_f64, 0.2, 0.0], [0.2, 5.0, 0.1], [0.0, 0.1, 6.0],];
sys.gb = array![0.5_f64, -0.3, 0.2];
let (delta_t, delta_beta) = sys.solve(0.0, 0.0).expect("arrow-schur solve");
// The streaming path with chunk size 1 must reproduce the in-memory result
// exactly (bitwise equality, not a tolerance).
let streaming_options = ArrowSolveOptions::direct().with_streaming_chunk_size(Some(1));
let (delta_t_stream, delta_beta_stream) = sys
.solve_with_options(0.0, 0.0, &streaming_options)
.expect("streaming arrow-schur solve");
assert_eq!(delta_beta, delta_beta_stream);
assert_eq!(delta_t, delta_t_stream);
// Assemble the joint system with the shared block first, followed by one
// d-sized block per row.
let total = k + n * d;
let mut hjoint = Array2::<f64>::zeros((total, total));
let mut gjoint = Array1::<f64>::zeros(total);
for a in 0..k {
for b in 0..k {
hjoint[[a, b]] = sys.hbb[[a, b]];
}
gjoint[a] = sys.gb[a];
}
for i in 0..n {
let toff = k + i * d;
for a in 0..d {
for b in 0..d {
hjoint[[toff + a, toff + b]] = sys.rows[i].htt[[a, b]];
}
gjoint[toff + a] = sys.rows[i].gt[a];
for a2 in 0..k {
hjoint[[toff + a, a2]] = sys.rows[i].htbeta[[a, a2]];
hjoint[[a2, toff + a]] = sys.rows[i].htbeta[[a, a2]];
}
}
}
// Dense reference: solve the joint system against the negated gradient.
let lj = cholesky_lower(&hjoint).expect("dense ref PD");
let neg_g = gjoint.mapv(|v| -v);
let xref = chol_solve_vector(&lj, &neg_g);
for a in 0..k {
assert!(
(xref[a] - delta_beta[a]).abs() < 1e-10,
"β[{a}] mismatch: dense {} vs arrow {}",
xref[a],
delta_beta[a]
);
}
for i in 0..n {
for a in 0..d {
let dense = xref[k + i * d + a];
let arrow = delta_t[i * d + a];
assert!(
(dense - arrow).abs() < 1e-10,
"t[{i},{a}] mismatch: dense {dense} vs arrow {arrow}"
);
}
}
}
// Scalar objective used by the proximal-correction test.
fn quartic_counterexample_value(t: f64) -> f64 {
0.25 * t.powi(4) - t * t + 2.0 * t
}
// One-row, one-dimensional system (no shared coefficients) holding the first
// and second derivative of the quartic above at `t`.
fn quartic_counterexample_system(t: f64) -> ArrowSchurSystem {
let mut sys = ArrowSchurSystem::new(1, 1, 0);
sys.rows[0].gt = array![t.powi(3) - 2.0 * t + 2.0];
sys.rows[0].htt = array![[3.0 * t * t - 2.0]];
sys
}
// Iterates the proximal correction from t = 0 and checks that every accepted
// step is non-increasing and that the gradient is tiny after 32 steps.
#[test]
fn proximal_correction_breaks_scalar_newton_cycle() {
let options = ArrowSolveOptions::direct();
let correction = ArrowProximalCorrectionOptions {
initial_ridge: 1e-8,
ridge_growth: 10.0,
max_attempts: 16,
armijo_c1: 1e-4,
gradient_tolerance: 1e-12,
};
let mut t = 0.0_f64;
let mut previous_value = quartic_counterexample_value(t);
for _ in 0..32 {
let sys = quartic_counterexample_system(t);
let accepted = solve_arrow_newton_step_with_proximal_correction(
&sys,
0.0,
0.0,
previous_value,
&options,
&correction,
|delta_t, _delta_beta| quartic_counterexample_value(t + delta_t[0]),
)
.expect("proximal correction should accept a descent step");
assert!(
accepted.trial_objective_value <= previous_value,
"accepted step must not increase the objective"
);
t += accepted.delta_t[0];
previous_value = accepted.trial_objective_value;
}
let final_grad = t.powi(3) - 2.0 * t + 2.0;
assert!(
final_grad.abs() < 1e-7,
"corrected iteration should reach the scalar critical point; t={t}, g={final_grad}"
);
}
// A nearly singular 2×2 row block must be rejected as ill-conditioned at
// zero ridge, while a well-conditioned block must still factor.
#[test]
fn factor_one_row_rejects_barely_pd_block() {
let d = 2;
let k = 2;
let mut row = ArrowRowBlock::new(d, k);
row.htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
row.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row.gt = array![0.0_f64, 0.0];
let err = factor_one_row(&row, 0.0, d, 0)
.expect_err("barely-PD H_tt must be rejected by the condition check");
match err {
ArrowSchurError::PerRowFactorIllConditioned {
row: r,
kappa_estimate,
} => {
assert_eq!(r, 0);
assert!(
kappa_estimate > 1e10,
"kappa estimate should reflect the barely-PD block; got {kappa_estimate:e}"
);
}
other => panic!("expected PerRowFactorIllConditioned, got {other:?}"),
}
let mut row_ok = ArrowRowBlock::new(d, k);
row_ok.htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
row_ok.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row_ok.gt = array![0.0_f64, 0.0];
factor_one_row(&row_ok, 0.0, d, 0)
.expect("well-conditioned block must still factor at ridge_t=0");
}
// The same nearly singular block fails the plain per-row factorization, but
// the escalating solve must return a finite step.
#[test]
fn lm_escalation_recovers_from_ill_conditioned_row() {
let n = 1;
let d = 2;
let k = 2;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
sys.rows[0].gt = array![0.1_f64, -0.2];
sys.hbb = array![[4.0_f64, 0.2], [0.2, 5.0]];
sys.gb = array![0.3_f64, -0.1];
let direct = factor_one_row(&sys.rows[0], 0.0, d, 0);
assert!(matches!(
direct,
Err(ArrowSchurError::PerRowFactorIllConditioned { .. })
));
let options = ArrowSolveOptions::direct();
let (delta_t, delta_beta) = solve_with_lm_escalation_inner(&sys, 0.0, 0.0, &options)
.expect("LM escalation must recover from PerRowFactorIllConditioned");
for v in delta_t.iter().chain(delta_beta.iter()) {
assert!(v.is_finite(), "recovered step must be finite: {v}");
}
}
}