use super::*;
/// Solves one Arrow-Schur Newton step and returns the step together with a
/// factor cache describing the factorizations that produced it.
///
/// Returns `(delta_t, delta_beta, cache)`.
///
/// # Errors
///
/// * `ArrowSchurError::SchurFactorFailed` when `options.streaming_chunk_size`
///   is set: this entry point needs the materialized factors.
/// * Any error propagated from `solve_arrow_newton_step_artifacts`,
///   `factor_blocks_for_system`, `CrossRowWoodbury::build`, or
///   `apply_inverse_correction`.
pub fn solve_arrow_newton_step_with_options(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, ArrowFactorCache), ArrowSchurError> {
// The streaming solver does not produce the per-row factors stored in the cache.
if options.streaming_chunk_size.is_some() {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming Arrow-Schur solve does not materialize the factor cache required by this entry point".to_string(),
});
}
// Declared outside the `match` so the downdated clone outlives the borrow
// that shadows `sys` below.
let downdated_owner;
let (sys, ibp_source): (&ArrowSchurSystem, Option<&IbpCrossRowSource>) =
match sys.ibp_cross_row.as_ref() {
Some(source) => {
// With an IBP cross-row source present, solve against a copy whose
// per-row `htt` diagonals have the source's self-term downdate
// subtracted; the cross-row part is re-applied further down via
// the Woodbury correction.
let mut downdated = sys.clone();
let total_len = downdated.row_offsets[downdated.rows.len()];
let down = source.self_term_downdate(total_len);
let offsets = Arc::clone(&downdated.row_offsets);
for (i, row) in downdated.rows.iter_mut().enumerate() {
let base = offsets[i];
let di = row.htt.nrows();
for j in 0..di {
row.htt[[j, j]] -= down[base + j];
}
}
// The row Hessians changed, so the fingerprint is refreshed to match.
downdated.refresh_row_hessian_fingerprint();
downdated_owner = downdated;
(&downdated_owner, Some(source))
}
None => (sys, None),
};
let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
let backend = CpuBatchedBlockSolver;
// An overflowing size estimate is treated as "too large" (usize::MAX).
let htbeta_estimated_bytes =
estimated_htbeta_bytes(sys.rows.len(), sys.d, sys.k).unwrap_or(usize::MAX);
// Choose how the cache keeps the cross block: the matvec operator if the
// system has one, a dense copy if it fits the budget, otherwise disabled.
let htbeta = if let Some(op) = sys.htbeta_matvec.as_ref() {
ArrowHtbetaCache::Matvec {
op: Arc::clone(op),
estimated_bytes: htbeta_estimated_bytes,
}
} else if htbeta_estimated_bytes <= ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES {
ArrowHtbetaCache::Dense {
blocks: sys
.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>()
.into(),
estimated_bytes: htbeta_estimated_bytes,
}
} else {
ArrowHtbetaCache::Disabled {
estimated_bytes: htbeta_estimated_bytes,
}
};
let htt_factors = step.htt_factors;
// With `ridge_t == 0.0` the damped factors already are the undamped ones;
// otherwise the row blocks are factored a second time with zero ridge.
let (htt_factors_undamped, gauge_deflated_directions) = if ridge_t == 0.0 {
(
ArrowUndampedFactors::SameAsDamped,
step.gauge_deflated_directions,
)
} else {
let undamped = factor_blocks_for_system(sys, 0.0, options, &backend)?;
(
ArrowUndampedFactors::Owned(undamped.factors),
undamped.gauge_deflated_directions,
)
};
let mut cache = ArrowFactorCache {
htt_factors,
htt_factors_undamped,
schur_factor: step.schur_factor,
// Filled in at the end, once the optional Woodbury term is attached.
joint_hessian_log_det: None,
solver_mode: options.mode,
ridge_t,
ridge_beta,
htbeta,
d: sys.d,
row_dims: Arc::clone(&sys.row_dims),
row_offsets: Arc::clone(&sys.row_offsets),
k: sys.k,
manifold_mode_fingerprint: sys.manifold_mode_fingerprint,
row_hessian_fingerprint: sys.current_row_hessian_fingerprint(),
pcg_diagnostics: step.pcg_diagnostics,
gauge_deflated_directions,
cross_row_woodbury: None,
};
let mut delta_t = step.delta_t;
let mut delta_beta = step.delta_beta;
if let Some(source) = ibp_source {
// `build` may return `Ok(None)`, in which case the step is left as solved.
if let Some(woodbury) = CrossRowWoodbury::build(&cache, source)? {
// Snapshot of the uncorrected `delta_t`, passed as the first argument
// while `delta_t` itself is updated in place.
let h0inv_neg_g_t = delta_t.clone();
woodbury.apply_inverse_correction(
h0inv_neg_g_t.view(),
&source.entries,
&mut delta_t,
&mut delta_beta,
)?;
cache.cross_row_woodbury = Some(woodbury);
}
}
cache.joint_hessian_log_det = cache.compute_undamped_arrow_log_det();
Ok((delta_t, delta_beta, cache))
}
/// Estimates the bytes needed for `n * d * k` `f64` values.
///
/// Returns `None` if any intermediate product overflows `usize`.
pub(crate) fn estimated_htbeta_bytes(n: usize, d: usize, k: usize) -> Option<usize> {
    let element_bytes = std::mem::size_of::<f64>();
    // Multiply in the order n * d * k * sizeof(f64), bailing out on the first overflow.
    [d, k, element_bytes]
        .iter()
        .try_fold(n, |bytes, &factor| bytes.checked_mul(factor))
}
/// Solves one Arrow-Schur Newton step and returns `(delta_t, delta_beta, diagnostics)`.
///
/// Dispatch order:
/// 1. the streaming solver when `options.streaming_chunk_size` is set
///    (diagnostics are then the default value),
/// 2. the device direct solve when `try_device_arrow_direct` yields a result,
/// 3. the host solve, with a GPU Schur matvec injected into the options when
///    `maybe_inject_gpu_schur_matvec` provides one.
pub fn solve_arrow_newton_step_core(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    if let Some(chunk_size) = options.streaming_chunk_size {
        let streaming_options = options.with_streaming_solve_precision_default();
        let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
        let (delta_t, delta_beta, _) = streaming.solve(ridge_t, ridge_beta, &streaming_options)?;
        return Ok((delta_t, delta_beta, PcgDiagnostics::default()));
    }

    if let Some(device_outcome) = try_device_arrow_direct(sys, ridge_t, ridge_beta, options) {
        return device_outcome;
    }

    // Either the caller's options or a copy carrying an injected GPU matvec.
    let injected = maybe_inject_gpu_schur_matvec(sys, ridge_t, ridge_beta, options);
    let effective_options = injected.as_ref().unwrap_or(options);
    let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, effective_options)?;

    let mut diagnostics = step.pcg_diagnostics;
    if injected.is_some() {
        diagnostics.injected_host_procedural_matvec = true;
    }
    Ok((step.delta_t, step.delta_beta, diagnostics))
}
/// Returns a copy of `options` with a GPU Schur matvec backend attached, or
/// `None` when injection does not apply.
///
/// Injection is skipped unless the mode is `InexactPCG`, no `gpu_matvec` is
/// already set, the system has no cross-row penalties, streaming is off, a
/// global GPU runtime exists, its policy approves the offload, and the backend
/// can be built.
pub(crate) fn maybe_inject_gpu_schur_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<ArrowSolveOptions> {
    let eligible = options.mode == ArrowSolverMode::InexactPCG
        && options.gpu_matvec.is_none()
        && sys.cross_row_penalties.is_empty()
        && options.streaming_chunk_size.is_none();
    if !eligible {
        return None;
    }

    let runtime = crate::gpu::device_runtime::GpuRuntime::global()?;
    // Iteration budget handed to the offload policy: the tighter of the PCG
    // and trust-region limits.
    let cg_iters = options
        .pcg
        .max_iterations
        .min(options.trust_region.max_iterations);
    let should_offload = runtime.policy().reduced_schur_matvec_should_offload(
        sys.rows.len(),
        sys.k,
        sys.d,
        cg_iters,
    );
    if !should_offload {
        return None;
    }

    // A backend construction failure silently falls back to the host path.
    let backend =
        crate::gpu::kernels::arrow_schur::gpu_schur_matvec_backend(sys, ridge_t, ridge_beta)
            .ok()?;
    let mut injected = options.clone();
    injected.gpu_matvec = Some(backend);
    Some(injected)
}
/// Tries the device (GPU) direct Arrow-Schur solve.
///
/// Returns `None` when the device path does not apply or fails in a way that
/// should fall through to the host solver. Returns `Some(result)` when the
/// device produced a step or a failure that maps onto an `ArrowSchurError`.
pub(crate) fn try_device_arrow_direct(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError>> {
    use crate::gpu::kernels::arrow_schur::ArrowSchurGpuFailure;

    if options.mode != ArrowSolverMode::Direct {
        return None;
    }
    // Features the device kernel path is not used with.
    let unsupported = !sys.cross_row_penalties.is_empty()
        || options.streaming_chunk_size.is_some()
        || sys.hbb_matvec.is_some()
        || sys.htbeta_matvec.is_some();
    if unsupported {
        return None;
    }

    let runtime = crate::gpu::device_runtime::GpuRuntime::global()?;
    if !runtime
        .policy()
        .dense_hessian_work_target_is_gpu(sys.rows.len(), sys.k)
    {
        return None;
    }

    let outcome =
        match crate::gpu::kernels::arrow_schur::solve_arrow_newton_step(sys, ridge_t, ridge_beta) {
            Ok(solution) => {
                let mut diagnostics = PcgDiagnostics::default();
                diagnostics.used_device_arrow = true;
                Ok((solution.delta_t, solution.delta_beta, diagnostics))
            }
            Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
                Err(ArrowSchurError::PerRowFactorFailed {
                    row,
                    reason: format!("device per-row block non-PD; suggested ridge bump {bump:e}"),
                })
            }
            Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
                Err(ArrowSchurError::SchurFactorFailed { reason })
            }
            // Any other device failure defers to the host solver.
            Err(_) => return None,
        };
    Some(outcome)
}
/// Solves the Newton step, adding a growing proximal ridge to both `ridge_t`
/// and `ridge_beta` each time the solve fails with a recoverable error.
///
/// The first attempt uses no extra ridge. Each retry starts at
/// `DEFAULT_PROXIMAL_INITIAL_RIDGE` and then multiplies by
/// `DEFAULT_PROXIMAL_RIDGE_GROWTH`, for at most
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS` escalations. On success the number of
/// escalations is recorded in `ridge_escalations`.
///
/// # Errors
///
/// Returns the last solver error, either immediately for a non-recoverable
/// error or once the escalation budget is exhausted.
pub fn solve_with_lm_escalation_inner(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let mut last_err: Option<ArrowSchurError> = None;
    let mut extra_ridge = 0.0_f64;
    let mut escalations: usize = 0;

    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
        let outcome = solve_arrow_newton_step_core(
            sys,
            ridge_t + extra_ridge,
            ridge_beta + extra_ridge,
            options,
        );
        let err = match outcome {
            Ok((delta_t, delta_beta, mut diagnostics)) => {
                diagnostics.ridge_escalations = escalations;
                return Ok((delta_t, delta_beta, diagnostics));
            }
            Err(err) => err,
        };

        // Only factorization / PCG / curvature failures can be cured by more damping.
        let recoverable = matches!(
            err,
            ArrowSchurError::PerRowFactorFailed { .. }
                | ArrowSchurError::PerRowFactorIllConditioned { .. }
                | ArrowSchurError::SchurFactorFailed { .. }
                | ArrowSchurError::PcgFailed { .. }
                | ArrowSchurError::UnboundedNegativeCurvature { .. }
        );
        last_err = Some(err);
        if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
            break;
        }

        extra_ridge = if extra_ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            extra_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
        escalations += 1;
    }

    Err(last_err.expect("escalation loop set last_err on failure"))
}
/// Searches for an acceptable Newton step by re-solving with an increasing
/// proximal ridge until a trial step satisfies the Armijo condition.
///
/// `trial_objective` is called with each candidate `(delta_t, delta_beta)` and
/// must return the objective value at that trial point. It is `FnMut`, and the
/// best-decrease fallback calls it once more on the chosen step before
/// returning.
///
/// Outcomes, in order of preference:
/// 1. a zero step with `attempts: 0` when the gradient norm is already within
///    `correction.gradient_tolerance`;
/// 2. the first candidate that is a descent direction and passes Armijo;
/// 3. after `correction.max_attempts`, the rejected candidate with the lowest
///    objective among those that decreased it by more than the objective
///    resolution;
/// 4. a zero step when the smallest observed objective change is within the
///    objective resolution;
/// 5. otherwise an error.
///
/// # Errors
///
/// `ArrowSchurError::AdaptiveCorrectionFailed` for a non-finite current
/// objective, an invalid `ridge_growth` or `armijo_c1`, or when no attempt
/// produced an acceptable step.
pub fn solve_arrow_newton_step_with_proximal_correction<F>(
sys: &ArrowSchurSystem,
base_ridge_t: f64,
base_ridge_beta: f64,
current_objective_value: f64,
options: &ArrowSolveOptions,
correction: &ArrowProximalCorrectionOptions,
mut trial_objective: F,
) -> Result<ArrowAcceptedProximalStep, ArrowSchurError>
where
F: for<'a, 'b> FnMut(ArrayView1<'a, f64>, ArrayView1<'b, f64>) -> f64,
{
// --- Argument validation -------------------------------------------------
if !current_objective_value.is_finite() {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: "current objective is not finite".to_string(),
});
}
// Written as `!(finite && > 1)` so that NaN is rejected as well.
if !(correction.ridge_growth.is_finite() && correction.ridge_growth > 1.0) {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"ridge_growth must be finite and > 1; got {}",
correction.ridge_growth
),
});
}
if !(correction.armijo_c1.is_finite()
&& correction.armijo_c1 > 0.0
&& correction.armijo_c1 < 1.0)
{
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!("armijo_c1 must be in (0, 1); got {}", correction.armijo_c1),
});
}
// --- Already stationary: return the zero step without solving ------------
let grad_norm = arrow_gradient_norm(sys);
if grad_norm <= correction.gradient_tolerance.max(0.0) {
return Ok(ArrowAcceptedProximalStep {
delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
delta_beta: Array1::<f64>::zeros(sys.k),
ridge_t: base_ridge_t,
ridge_beta: base_ridge_beta,
proximal_ridge: 0.0,
objective_value: current_objective_value,
trial_objective_value: current_objective_value,
gradient_dot_step: 0.0,
attempts: 0,
});
}
// Objective changes smaller than this are treated as indistinguishable from zero.
let objective_resolution =
correction.convergence_objective_rel_tol.max(0.0) * (current_objective_value.abs() + 1.0);
let mut proximal_ridge = correction.initial_ridge.max(0.0);
let mut last_reason = String::from("no attempts were made");
// Best Armijo-rejected candidate that still decreased the objective:
// (delta_t, delta_beta, trial_value, g_dot_p, ridge_t, ridge_beta, proximal_ridge).
let mut best_decrease: Option<(Array1<f64>, Array1<f64>, f64, f64, f64, f64, f64)> = None;
// Smallest objective change among finite trials that did not count as a decrease.
let mut smallest_increase = f64::INFINITY;
for attempt in 0..correction.max_attempts {
let ridge_t = base_ridge_t + proximal_ridge;
let ridge_beta = base_ridge_beta + proximal_ridge;
match solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options) {
Ok((delta_t, delta_beta, _diag)) => {
let g_dot_p = arrow_gradient_dot_step(sys, delta_t.view(), delta_beta.view());
// A candidate must be a strict, finite descent direction before the
// objective is evaluated at all.
if !(g_dot_p.is_finite() && g_dot_p < 0.0) {
last_reason =
format!("candidate was not a finite descent direction: g·p={g_dot_p}");
} else {
let trial_value = trial_objective(delta_t.view(), delta_beta.view());
// Armijo sufficient-decrease bound for a unit step length.
let armijo_bound = current_objective_value + correction.armijo_c1 * g_dot_p;
if trial_value.is_finite() && trial_value <= armijo_bound {
return Ok(ArrowAcceptedProximalStep {
delta_t,
delta_beta,
ridge_t,
ridge_beta,
proximal_ridge,
objective_value: current_objective_value,
trial_objective_value: trial_value,
gradient_dot_step: g_dot_p,
attempts: attempt + 1,
});
}
// Armijo failed; remember the candidate if it is still useful as a fallback.
if trial_value.is_finite() {
let delta_obj = trial_value - current_objective_value;
if delta_obj < -objective_resolution {
let improves = best_decrease.as_ref().is_none_or(
|(_, _, best_value, _, _, _, _)| trial_value < *best_value,
);
if improves {
best_decrease = Some((
delta_t.clone(),
delta_beta.clone(),
trial_value,
g_dot_p,
ridge_t,
ridge_beta,
proximal_ridge,
));
}
} else if delta_obj < smallest_increase {
smallest_increase = delta_obj;
}
}
last_reason = {
let step_norm = (delta_t.iter().map(|v| v * v).sum::<f64>()
+ delta_beta.iter().map(|v| v * v).sum::<f64>())
.sqrt();
format!(
"Armijo rejected trial objective {trial_value}; bound {armijo_bound}; \
|g|={grad_norm:.4e} g.p={g_dot_p:.4e} |step|={step_norm:.4e} ridge={proximal_ridge:.3e}"
)
};
}
}
// A failed solve is not fatal here: record it and retry with more ridge.
Err(err) => {
last_reason = err.to_string();
}
}
proximal_ridge = next_proximal_ridge(proximal_ridge, correction.ridge_growth);
}
// --- Fallback 1: best objective-decreasing candidate ----------------------
if let Some((delta_t, delta_beta, trial_value, g_dot_p, ridge_t, ridge_beta, best_ridge)) =
best_decrease
{
// Re-evaluate at the chosen step; presumably this re-synchronizes any state
// the closure keeps from its most recent call — confirm with callers.
let reapplied = trial_objective(delta_t.view(), delta_beta.view());
let final_value = if reapplied.is_finite() {
reapplied
} else {
trial_value
};
return Ok(ArrowAcceptedProximalStep {
delta_t,
delta_beta,
ridge_t,
ridge_beta,
proximal_ridge: best_ridge,
objective_value: current_objective_value,
trial_objective_value: final_value,
gradient_dot_step: g_dot_p,
attempts: correction.max_attempts,
});
}
// --- Fallback 2: objective is flat within resolution, report a zero step --
if smallest_increase.is_finite() && smallest_increase <= objective_resolution {
return Ok(ArrowAcceptedProximalStep {
delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
delta_beta: Array1::<f64>::zeros(sys.k),
ridge_t: base_ridge_t,
ridge_beta: base_ridge_beta,
proximal_ridge: 0.0,
objective_value: current_objective_value,
trial_objective_value: current_objective_value,
gradient_dot_step: 0.0,
attempts: correction.max_attempts,
});
}
Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"failed after {} attempts; last rejection: {last_reason}",
correction.max_attempts
),
})
}
/// Predicted reduction `-(g·p + ½ pᵀ M p)` of the ridge-damped quadratic model
/// for the step `p = (delta_t, delta_beta)`.
///
/// `M` is assembled from the per-row `htt` blocks plus `ridge_t` on their
/// diagonal, the β-block operator applied through `penalty_matvec_add` plus
/// `ridge_beta` on its diagonal, and the cross block applied through
/// `sys_htbeta_apply_row` (counted twice for the two symmetric off-diagonal
/// blocks).
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` is kept for
/// interface stability.
///
/// # Panics
///
/// Panics if `delta_t` or `delta_beta` has the wrong length, or if
/// `delta_beta` is not contiguous.
pub fn arrow_damped_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let total_len = sys.row_offsets[sys.rows.len()];
    assert_eq!(delta_t.len(), total_len);
    assert_eq!(delta_beta.len(), sys.k);

    // β-only contributions: gradient term, ridge term, and the β-block quadratic form.
    let mut lin = sys.gb.dot(&delta_beta);
    let mut quad = ridge_beta * delta_beta.dot(&delta_beta);
    let mut hbb_delta = Array1::<f64>::zeros(sys.k);
    {
        let x_slice = delta_beta
            .as_slice()
            .expect("delta_beta must be contiguous");
        let y_slice = hbb_delta
            .as_slice_mut()
            .expect("hbb_delta must be contiguous");
        sys.penalty_matvec_add(x_slice, y_slice);
    }
    quad += delta_beta.dot(&hbb_delta);

    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let row_base = sys.row_offsets[i];
        // Zeroed scratch sized to this row, matching `back_substitute_delta_t`.
        // The previous code sliced a `sys.d`-long buffer and copied the slice,
        // which never wrote to that buffer and panicked when `di > sys.d`.
        let mut htbeta_x_i = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, delta_beta, &mut htbeta_x_i);
        for c in 0..di {
            let dt_c = delta_t[row_base + c];
            lin += row.gt[c] * dt_c;
            quad += ridge_t * dt_c * dt_c;
            for r in 0..di {
                quad += dt_c * row.htt[[c, r]] * delta_t[row_base + r];
            }
            // Cross term appears twice in pᵀ M p (t-β and β-t blocks).
            quad += 2.0 * dt_c * htbeta_x_i[c];
        }
    }
    Ok(-(lin + 0.5 * quad))
}
/// Predicted reduction of the quadratic model with the ridge terms removed.
///
/// Computed as the damped reduction plus `½·ridge_beta·‖delta_beta‖²` and
/// `½·ridge_t·‖delta_t‖²`, which cancels the ridge contribution that
/// `arrow_damped_quadratic_model_reduction` subtracts.
///
/// # Errors
///
/// Propagates any error from `arrow_damped_quadratic_model_reduction`.
pub fn arrow_bare_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let damped_reduction =
        arrow_damped_quadratic_model_reduction(sys, delta_t, delta_beta, ridge_t, ridge_beta)?;
    let beta_ridge_term = 0.5 * ridge_beta * delta_beta.dot(&delta_beta);
    // Sequential accumulation, element by element.
    let delta_t_sq_norm = delta_t.iter().fold(0.0_f64, |acc, v| acc + v * v);
    let t_ridge_term = 0.5 * ridge_t * delta_t_sq_norm;
    Ok(damped_reduction + beta_ridge_term + t_ridge_term)
}
/// Next proximal ridge in the escalation schedule.
///
/// A strictly positive `current` is multiplied by `growth`; anything else
/// (zero, negative, or NaN) restarts at `DEFAULT_PROXIMAL_INITIAL_RIDGE`.
pub(crate) fn next_proximal_ridge(current: f64, growth: f64) -> f64 {
    if current > 0.0 {
        return current * growth;
    }
    DEFAULT_PROXIMAL_INITIAL_RIDGE
}
pub(crate) fn arrow_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
let mut sum = 0.0;
for row in sys.rows.iter() {
for &v in row.gt.iter() {
sum += v * v;
}
}
for &v in sys.gb.iter() {
sum += v * v;
}
sum.sqrt()
}
/// Inner product of the gradient `(gt rows, gb)` with the step
/// `(delta_t, delta_beta)`.
///
/// # Panics
///
/// Panics if `delta_t` does not have `row_offsets[rows.len()]` entries or
/// `delta_beta` does not have `k` entries.
pub(crate) fn arrow_gradient_dot_step(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_t.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);

    // One accumulator threaded through every row, then through the β block.
    let row_part = sys.rows.iter().enumerate().fold(0.0, |acc, (i, row)| {
        let di = sys.row_dims[i];
        let offset = sys.row_offsets[i];
        (0..di).fold(acc, |inner, c| inner + row.gt[c] * delta_t[offset + c])
    });
    (0..sys.k).fold(row_part, |acc, a| acc + sys.gb[a] * delta_beta[a])
}
/// Result bundle of `solve_arrow_newton_step_artifacts`: the Newton step plus
/// the factorizations and diagnostics produced while computing it.
pub(crate) struct ArrowNewtonStepArtifacts {
// Step for the per-row unknowns, stored flat and indexed through `row_offsets`.
pub(crate) delta_t: Array1<f64>,
// Step for the shared β unknowns.
pub(crate) delta_beta: Array1<f64>,
// Per-row block factors; the streaming path returns an empty slab.
pub(crate) htt_factors: ArrowFactorSlab,
// Factor of the reduced Schur system when the solve path produced one
// (the device PCG path returns `None`).
pub(crate) schur_factor: Option<Array2<f64>>,
pub(crate) pcg_diagnostics: PcgDiagnostics,
// Gauge-deflated direction count reported by the block factorization.
pub(crate) gauge_deflated_directions: usize,
}
/// Output of `factor_blocks_for_system`.
pub(crate) struct ArrowBlockFactorization {
// Factors of the per-row blocks.
pub(crate) factors: ArrowFactorSlab,
// Sum of the per-row deflated-direction counts; 0 when the system has no
// row gauge deflation.
pub(crate) gauge_deflated_directions: usize,
}
/// Factors every per-row block of `sys` with `ridge_t` applied.
///
/// Without row gauge deflation the whole batch goes to
/// `backend.factor_blocks`. With it, rows are factored one at a time through
/// `factor_one_row_result` so each row receives its own deflation data, and
/// the per-row deflated-direction counts are summed.
///
/// # Errors
///
/// Propagates the first factorization error.
pub(crate) fn factor_blocks_for_system<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    options: &ArrowSolveOptions,
    backend: &B,
) -> Result<ArrowBlockFactorization, ArrowSchurError> {
    match sys.row_gauge_deflation.as_ref() {
        None => {
            let factors = backend.factor_blocks(
                &sys.rows,
                ridge_t,
                sys.d,
                options.tolerate_ill_conditioning,
            )?;
            Ok(ArrowBlockFactorization {
                factors,
                gauge_deflated_directions: 0,
            })
        }
        Some(deflation) => {
            let mut row_factors = Vec::with_capacity(sys.rows.len());
            let mut deflated_total = 0usize;
            for (row_idx, row) in sys.rows.iter().enumerate() {
                let outcome = factor_one_row_result(
                    row,
                    ridge_t,
                    sys.row_dims[row_idx],
                    row_idx,
                    options.tolerate_ill_conditioning,
                    deflation.row(row_idx),
                    true,
                )?;
                deflated_total += outcome.gauge_deflated_directions;
                row_factors.push(outcome.factor);
            }
            Ok(ArrowBlockFactorization {
                factors: ArrowFactorSlab::from_blocks(row_factors),
                gauge_deflated_directions: deflated_total,
            })
        }
    }
}
/// Outcome of `try_mixed_precision_arrow_solve` when the mixed-precision
/// policy is enabled.
pub(crate) enum MixedPrecisionAttempt {
// The refined solution passed the backward-error certificate.
Certified {
delta_t: Array1<f64>,
delta_beta: Array1<f64>,
// f64 Cholesky factor of the dense Schur complement used for the solve.
schur_factor: Array2<f64>,
// Number of refinement corrections applied before certification.
refinement_steps: usize,
},
// The mixed-precision path was declined or did not certify; `reason` is a
// human-readable explanation that the caller logs.
Fallback {
reason: String,
},
}
/// Recovers the per-row step `delta_t` from a solved `delta_beta`.
///
/// For each row `i` the right-hand side is `gt_i` plus the vector written by
/// `sys_htbeta_apply_row` for `delta_beta`; it is solved against the stored
/// row factor and the negated solution is written into the row's segment of
/// the output.
///
/// Rows are processed in parallel in chunks of 64 when the row count reaches
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and the caller is not already running on a
/// rayon worker thread; otherwise they are processed sequentially.
///
/// # Panics
///
/// Panics if a row's `gt` length differs from its entry in `row_dims`, or if
/// the row offsets do not describe contiguous segments.
pub(crate) fn back_substitute_delta_t<B: BatchedBlockSolver + Sync>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
delta_beta: ArrayView1<'_, f64>,
backend: &B,
) -> Array1<f64> {
let n = sys.rows.len();
let total_dt_len = sys.row_offsets[n];
let mut delta_t = Array1::<f64>::zeros(total_dt_len);
// Skip the parallel path when already on a rayon worker thread.
let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
// Solves row `i` and writes its `di` entries into `out` (row-local indexing).
let solve_row = |i: usize, out: &mut [f64]| {
let di = sys.row_dims[i];
assert!(
sys.rows[i].gt.len() == di,
"back_substitute_delta_t: row {i} gt len {} != row dim {di}",
sys.rows[i].gt.len()
);
let mut htbeta_slice = Array1::<f64>::zeros(di);
// Presumably fills `htbeta_slice` with the row's cross block applied to
// `delta_beta` — confirm against `sys_htbeta_apply_row`.
sys_htbeta_apply_row(sys, i, &sys.rows[i], delta_beta, &mut htbeta_slice);
let mut rhs = Array1::<f64>::zeros(di);
for c in 0..di {
rhs[c] = sys.rows[i].gt[c] + htbeta_slice[c];
}
let dt_i = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
for c in 0..di {
out[c] = -dt_i[c];
}
};
if parallel {
use rayon::prelude::*;
const CHUNK: usize = 64;
let row_offsets = &sys.row_offsets;
let dt_slice = delta_t.as_slice_mut().expect("delta_t contiguous");
let n_chunks = n.div_ceil(CHUNK);
// Carve the output into disjoint mutable segments, one per chunk of rows,
// so each parallel task owns exactly the slice it writes.
let mut remaining = dt_slice;
let mut segments: Vec<(usize, &mut [f64])> = Vec::with_capacity(n_chunks);
let mut prev_end = 0usize;
for chunk in 0..n_chunks {
let start = chunk * CHUNK;
let end = (start + CHUNK).min(n);
let seg_len = row_offsets[end] - row_offsets[start];
// Successive `split_at_mut` calls only line up with `row_offsets` if each
// chunk starts exactly where the previous one ended.
assert!(
prev_end == row_offsets[start],
"back_substitute_delta_t: non-contiguous row segment at chunk start {start} \
(prev_end={prev_end}, row_offset={})",
row_offsets[start]
);
let (seg, rest) = remaining.split_at_mut(seg_len);
remaining = rest;
segments.push((start, seg));
prev_end = row_offsets[end];
}
segments.into_par_iter().for_each(|(start, seg)| {
let end = (start + CHUNK).min(n);
// Running offset of the current row inside this chunk's segment.
let mut local = 0usize;
for i in start..end {
let di = sys.row_dims[i];
solve_row(i, &mut seg[local..local + di]);
local += di;
}
});
} else {
for i in 0..n {
let row_base = sys.row_offsets[i];
let di = sys.row_dims[i];
solve_row(
i,
delta_t
.as_slice_mut()
.expect("delta_t contiguous")
.get_mut(row_base..row_base + di)
.expect("row segment in bounds"),
);
}
}
delta_t
}
/// Attempts the certified mixed-precision solve: an f32 solve refined by
/// f32 corrections against f64 residuals, accepted only when a backward-error
/// certificate meets the tolerance.
///
/// Returns:
/// * `Ok(None)` when `options.solve_precision` is not `CertifiedMixed`;
/// * `Ok(Some(Fallback { .. }))` when the path is declined (finite
///   trust-region radius, kappa gate refusal, non-finite f32 values) or the
///   certificate does not converge within `max_refinement_steps`;
/// * `Ok(Some(Certified { .. }))` with the refined step and the f64 Schur
///   factor on success.
///
/// # Errors
///
/// `ArrowSchurError::SchurFactorFailed` when the Schur Cholesky fails, or when
/// it is ill-conditioned and `options.tolerate_ill_conditioning` is false.
/// Errors from `solve_arrow_system_f32` and
/// `arrow_backward_error_certificate` are propagated.
pub(crate) fn try_mixed_precision_arrow_solve(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    htt_factors: &ArrowFactorSlab,
    schur: &Array2<f64>,
    options: &ArrowSolveOptions,
) -> Result<Option<MixedPrecisionAttempt>, ArrowSchurError> {
    let ArrowSolvePrecisionPolicy::CertifiedMixed {
        max_refinement_steps,
        residual_relative_tolerance,
        kappa_unit_roundoff_margin,
    } = options.solve_precision
    else {
        return Ok(None);
    };
    if options.trust_region.radius.is_finite() {
        return Ok(Some(MixedPrecisionAttempt::Fallback {
            reason: "trust-region-truncated dense solves are not certified by the mixed-precision refinement path".to_string(),
        }));
    }
    let schur_factor =
        cholesky_lower(schur).map_err(|e| ArrowSchurError::SchurFactorFailed { reason: e })?;
    if !options.tolerate_ill_conditioning {
        let schur_kappa = cholesky_factor_kappa_estimate(&schur_factor);
        if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
                     (H_tt)^-1 contamination would yield an inaccurate delta_beta"
                ),
            });
        }
    }
    if let Some(reason) =
        mixed_precision_kappa_gate_failure(htt_factors, &schur_factor, kappa_unit_roundoff_margin)
    {
        return Ok(Some(MixedPrecisionAttempt::Fallback { reason }));
    }

    let row_factors_f32 = arrow_factor_slab_to_f32(htt_factors);
    let schur_factor_f32 = schur_factor.mapv(|v| v as f32);
    let (rhs_t, rhs_beta) = arrow_rhs(sys);
    let mut x = solve_arrow_system_f32(
        sys,
        &row_factors_f32,
        &schur_factor_f32,
        rhs_t.view(),
        rhs_beta.view(),
    )?;
    // The infinity norms behind the certificate are built on `f64::max`, which
    // discards NaN. A NaN-contaminated initial solve (e.g. from f32 overflow
    // of a large right-hand side) could therefore produce a small certificate
    // and be returned as certified. Refuse it here, mirroring the check that
    // is applied to every refinement correction below.
    if !x.0.iter().chain(x.1.iter()).all(|v| v.is_finite()) {
        return Ok(Some(MixedPrecisionAttempt::Fallback {
            reason: "f32 initial solve produced a non-finite value".to_string(),
        }));
    }

    // Never certify below a small multiple of f64 machine epsilon.
    let certificate_tol = residual_relative_tolerance
        .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
    for refinement_steps in 0..=max_refinement_steps {
        let (res_t, res_beta) = arrow_residual(
            sys,
            ridge_t,
            ridge_beta,
            x.0.view(),
            x.1.view(),
            rhs_t.view(),
            rhs_beta.view(),
        );
        let certificate = arrow_backward_error_certificate(
            sys,
            ridge_t,
            ridge_beta,
            x.0.view(),
            x.1.view(),
            rhs_t.view(),
            rhs_beta.view(),
            res_t.view(),
            res_beta.view(),
        )?;
        if certificate <= certificate_tol {
            return Ok(Some(MixedPrecisionAttempt::Certified {
                delta_t: x.0,
                delta_beta: x.1,
                schur_factor,
                refinement_steps,
            }));
        }
        if refinement_steps == max_refinement_steps {
            return Ok(Some(MixedPrecisionAttempt::Fallback {
                reason: format!(
                    "f64 residual certificate did not converge after {max_refinement_steps} refinement steps \
                     (backward_error={certificate:e}, tolerance={certificate_tol:e})"
                ),
            }));
        }
        // Classic iterative refinement: solve for the correction in f32 using
        // the f64 residual as the right-hand side.
        let correction = solve_arrow_system_f32(
            sys,
            &row_factors_f32,
            &schur_factor_f32,
            res_t.view(),
            res_beta.view(),
        )?;
        if !correction
            .0
            .iter()
            .chain(correction.1.iter())
            .all(|v| v.is_finite())
        {
            return Ok(Some(MixedPrecisionAttempt::Fallback {
                reason: "f32 refinement correction produced a non-finite value".to_string(),
            }));
        }
        for i in 0..x.0.len() {
            x.0[i] += correction.0[i];
        }
        for i in 0..x.1.len() {
            x.1[i] += correction.1[i];
        }
    }
    // Not reachable in practice: the final loop iteration always returns.
    // Kept so the function has a value on every path.
    Ok(Some(MixedPrecisionAttempt::Fallback {
        reason: "mixed refinement loop exhausted without certification".to_string(),
    }))
}
/// Decides whether the factors are well-conditioned enough for f32 refinement.
///
/// The worst condition estimate is taken over the Schur factor, every per-row
/// factor, and the global max/min Cholesky pivot ratio. The gate passes
/// (returns `None`) only when that estimate is finite and
/// `kappa * F32_UNIT_ROUNDOFF` is below the threshold
/// `margin.min(MIXED_PRECISION_KAPPA_MARGIN_CEILING).max(F32_UNIT_ROUNDOFF)`.
/// Otherwise a human-readable refusal reason is returned.
///
/// Any NaN condition estimate or pivot forces a refusal. `f64::max` and
/// `f64::min` return the non-NaN operand, so without explicit tracking a NaN
/// estimate would silently drop out of the aggregate and the gate could pass.
pub(crate) fn mixed_precision_kappa_gate_failure(
    htt_factors: &ArrowFactorSlab,
    schur_factor: &Array2<f64>,
    margin: f64,
) -> Option<String> {
    let mut max_kappa = cholesky_factor_kappa_estimate(schur_factor);
    let mut min_pivot = lower_cholesky_min_pivot(schur_factor.view());
    let mut max_pivot = lower_cholesky_max_pivot(schur_factor.view());
    // Set as soon as any estimate or pivot is NaN.
    let mut saw_nan = max_kappa.is_nan()
        || min_pivot.is_some_and(f64::is_nan)
        || max_pivot.is_some_and(f64::is_nan);

    for factor in htt_factors.iter() {
        let owned = factor.to_owned();
        let row_kappa = cholesky_factor_kappa_estimate(&owned);
        saw_nan |= row_kappa.is_nan();
        max_kappa = max_kappa.max(row_kappa);
        if let Some(pivot) = lower_cholesky_min_pivot(owned.view()) {
            saw_nan |= pivot.is_nan();
            min_pivot = Some(match min_pivot {
                Some(current) => current.min(pivot),
                None => pivot,
            });
        }
        if let Some(pivot) = lower_cholesky_max_pivot(owned.view()) {
            saw_nan |= pivot.is_nan();
            max_pivot = Some(match max_pivot {
                Some(current) => current.max(pivot),
                None => pivot,
            });
        }
    }

    // Fold in the pivot spread across all factors; a non-positive smallest
    // pivot or a non-finite largest pivot is treated as unbounded conditioning.
    if let (Some(min_pivot), Some(max_pivot)) = (min_pivot, max_pivot) {
        if min_pivot > 0.0 && max_pivot.is_finite() {
            max_kappa = max_kappa.max(max_pivot / min_pivot);
        } else {
            max_kappa = f64::INFINITY;
        }
    }
    if saw_nan {
        // Make the contamination visible in the message and fail the gate.
        max_kappa = f64::NAN;
    }

    let kappa_u = max_kappa * F32_UNIT_ROUNDOFF;
    let threshold = margin
        .min(MIXED_PRECISION_KAPPA_MARGIN_CEILING)
        .max(F32_UNIT_ROUNDOFF);
    if !(max_kappa.is_finite() && kappa_u < threshold) {
        Some(format!(
            "kappa gate refused f32 refinement: kappa_estimate={max_kappa:e}, \
             kappa*u_f32={kappa_u:e}, required < {threshold:e}"
        ))
    } else {
        None
    }
}
/// Converts every factor in the slab to an owned `f32` matrix, in slab order.
pub(crate) fn arrow_factor_slab_to_f32(htt_factors: &ArrowFactorSlab) -> Vec<Array2<f32>> {
    let mut narrowed: Vec<Array2<f32>> = Vec::new();
    for factor in htt_factors.iter() {
        // Plain `as` narrowing of each entry.
        narrowed.push(factor.mapv(|entry| entry as f32));
    }
    narrowed
}
/// Builds the Newton right-hand side `(-gt, -gb)`.
///
/// The first array stacks the negated row gradients at their `row_offsets`;
/// the second is the negated β gradient of length `k`.
pub(crate) fn arrow_rhs(sys: &ArrowSchurSystem) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut rhs_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let offset = sys.row_offsets[i];
        for c in 0..di {
            rhs_t[offset + c] = -row.gt[c];
        }
    }

    let mut rhs_beta = Array1::<f64>::zeros(sys.k);
    for (c, slot) in rhs_beta.iter_mut().enumerate() {
        *slot = -sys.gb[c];
    }
    (rhs_t, rhs_beta)
}
/// Solves the arrow system in `f32` using precomputed `f32` Cholesky factors
/// and returns the solution widened to `f64`.
///
/// Works by block elimination: each row block is solved first, its
/// contribution is removed from the β right-hand side, the reduced β system
/// is solved with `schur_factor`, and the row solutions are then corrected by
/// the β solution.
///
/// # Errors
///
/// Propagates errors from `sys_htbeta_materialize_row`.
pub(crate) fn solve_arrow_system_f32(
sys: &ArrowSchurSystem,
row_factors: &[Array2<f32>],
schur_factor: &Array2<f32>,
rhs_t: ArrayView1<'_, f64>,
rhs_beta: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
let n = sys.rows.len();
// Per-row intermediate solutions, kept for the back-substitution pass.
let mut y_rows = Vec::<Array1<f32>>::with_capacity(n);
let mut reduced_beta = rhs_beta.mapv(|v| v as f32);
// Forward pass: solve each row block and reduce the β right-hand side.
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
let rhs_i = rhs_t.slice(ndarray::s![base..base + di]).mapv(|v| v as f32);
let y_i = cholesky_solve_lower_f32(&row_factors[i], &rhs_i);
let htbeta = sys_htbeta_materialize_row(sys, i, &sys.rows[i])?.mapv(|v| v as f32);
// reduced_beta -= htbetaᵀ · y_i, accumulated in f32.
for beta_col in 0..sys.k {
let mut acc = 0.0_f32;
for row_axis in 0..di {
acc += htbeta[[row_axis, beta_col]] * y_i[row_axis];
}
reduced_beta[beta_col] -= acc;
}
y_rows.push(y_i);
}
let x_beta_f32 = cholesky_solve_lower_f32(schur_factor, &reduced_beta);
let mut x_t = Array1::<f64>::zeros(sys.row_offsets[n]);
// Backward pass: x_t,i = y_i - (row factor solve of htbeta · x_beta).
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
// The row's cross block is materialized a second time here rather than
// being kept from the forward pass.
let htbeta = sys_htbeta_materialize_row(sys, i, &sys.rows[i])?.mapv(|v| v as f32);
let mut cross = Array1::<f32>::zeros(di);
for row_axis in 0..di {
let mut acc = 0.0_f32;
for beta_col in 0..sys.k {
acc += htbeta[[row_axis, beta_col]] * x_beta_f32[beta_col];
}
cross[row_axis] = acc;
}
let correction = cholesky_solve_lower_f32(&row_factors[i], &cross);
for row_axis in 0..di {
x_t[base + row_axis] = (y_rows[i][row_axis] - correction[row_axis]) as f64;
}
}
let x_beta = x_beta_f32.mapv(|v| v as f64);
Ok((x_t, x_beta))
}
/// Solves `L Lᵀ x = b` in `f32` given the lower-triangular Cholesky factor `l`.
///
/// Performs a forward substitution with `L` followed by a backward
/// substitution with `Lᵀ` (read from the lower triangle of `l`).
///
/// # Panics
///
/// Panics if any diagonal entry of `l` is non-finite or smaller in magnitude
/// than `f32::MIN_POSITIVE`, or if `b` has fewer than `l.nrows()` entries.
pub(crate) fn cholesky_solve_lower_f32(l: &Array2<f32>, b: &Array1<f32>) -> Array1<f32> {
    let dim = l.nrows();
    for i in 0..dim {
        let pivot = l[[i, i]];
        assert!(
            pivot.is_finite() && pivot.abs() >= f32::MIN_POSITIVE,
            "cholesky_solve_lower_f32: factor diagonal must be finite and non-subnormal"
        );
    }

    // Forward substitution: L y = b.
    let mut forward = Array1::<f32>::zeros(dim);
    for i in 0..dim {
        let numerator = (0..i).fold(b[i], |acc, j| acc - l[[i, j]] * forward[j]);
        forward[i] = numerator / l[[i, i]];
    }

    // Backward substitution: Lᵀ x = y, using l[[j, i]] as the (i, j) entry of Lᵀ.
    let mut solution = Array1::<f32>::zeros(dim);
    for i in (0..dim).rev() {
        let numerator = ((i + 1)..dim).fold(forward[i], |acc, j| acc - l[[j, i]] * solution[j]);
        solution[i] = numerator / l[[i, i]];
    }
    solution
}
/// Residual `rhs - A·x` of the arrow system in `f64`, split into its
/// per-row and β parts, where `A·x` comes from `arrow_operator_apply`.
pub(crate) fn arrow_residual(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let (applied_t, applied_beta) = arrow_operator_apply(sys, ridge_t, ridge_beta, x_t, x_beta);

    // Start from a copy of the right-hand side and subtract A·x entry by entry.
    let mut res_t = rhs_t.to_owned();
    for (i, entry) in res_t.iter_mut().enumerate() {
        *entry -= applied_t[i];
    }
    let mut res_beta = rhs_beta.to_owned();
    for (i, entry) in res_beta.iter_mut().enumerate() {
        *entry -= applied_beta[i];
    }
    (res_t, res_beta)
}
/// Applies the ridge-damped arrow operator to `(x_t, x_beta)` and returns
/// `(y_t, y_beta)`.
///
/// The β output starts from `penalty_matvec_add` applied to `x_beta` plus
/// `ridge_beta * x_beta`; the per-row work is delegated to
/// `cross_row_matvec_row_into`, which fills the row's part of `y_t` and adds
/// into the β accumulator.
///
/// Rows are processed in parallel chunks of 64 when the row count reaches
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` and the caller is not already on a rayon
/// worker thread.
///
/// # Panics
///
/// Panics if `x_beta` is not contiguous.
pub(crate) fn arrow_operator_apply(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
x_t: ArrayView1<'_, f64>,
x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
let n = sys.rows.len();
let mut y_t = Array1::<f64>::zeros(sys.row_offsets[n]);
let mut y_beta = Array1::<f64>::zeros(sys.k);
{
let x_slice = x_beta.as_slice().expect("x_beta contiguous");
let y_slice = y_beta.as_slice_mut().expect("y_beta contiguous");
sys.penalty_matvec_add(x_slice, y_slice);
}
for beta_col in 0..sys.k {
y_beta[beta_col] += ridge_beta * x_beta[beta_col];
}
let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
if parallel {
use rayon::prelude::*;
const CHUNK: usize = 64;
// Each chunk produces (segment start, its slice of y_t, its β partial sum)
// into private buffers, so the parallel tasks share no mutable state.
let chunks: Vec<(usize, Vec<f64>, Array1<f64>)> = (0..n)
.into_par_iter()
.chunks(CHUNK)
.map(|idxs| {
let first = idxs[0];
let last = idxs[idxs.len() - 1];
let seg_start = sys.row_offsets[first];
let seg_end = sys.row_offsets[last] + sys.row_dims[last];
let mut seg = vec![0.0_f64; seg_end - seg_start];
let mut acc = Array1::<f64>::zeros(sys.k);
for i in idxs {
// NOTE(review): `seg_start` is presumably the offset the helper
// subtracts from the row's global position when writing into
// `seg` — confirm against `cross_row_matvec_row_into`.
cross_row_matvec_row_into(
sys, ridge_t, i, x_t, x_beta, seg_start, &mut seg, &mut acc,
);
}
(seg_start, seg, acc)
})
.collect();
// Sequential merge in chunk order: copy each segment into place and add
// the β partial sums.
for (seg_start, seg, acc) in &chunks {
for (o, v) in seg.iter().enumerate() {
y_t[seg_start + o] = *v;
}
for j in 0..sys.k {
y_beta[j] += acc[j];
}
}
} else {
let y_t_slice = y_t.as_slice_mut().expect("y_t contiguous");
// Sequential path writes straight into the full output (segment start 0).
for i in 0..n {
cross_row_matvec_row_into(sys, ridge_t, i, x_t, x_beta, 0, y_t_slice, &mut y_beta);
}
}
(y_t, y_beta)
}
/// Normwise backward-error estimate for an approximate solution:
/// `‖r‖∞ / (‖A‖∞ · ‖x‖∞ + ‖b‖∞)`.
///
/// When the denominator is not strictly positive the raw residual norm is
/// returned instead.
///
/// # Errors
///
/// Propagates errors from `arrow_operator_infinity_norm`.
pub(crate) fn arrow_backward_error_certificate(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
    res_t: ArrayView1<'_, f64>,
    res_beta: ArrayView1<'_, f64>,
) -> Result<f64, ArrowSchurError> {
    let residual_norm = infinity_norm_pair(res_t, res_beta);
    let scale = arrow_operator_infinity_norm(sys, ridge_t, ridge_beta)?
        * infinity_norm_pair(x_t, x_beta)
        + infinity_norm_pair(rhs_t, rhs_beta);
    let certificate = if scale > 0.0 {
        residual_norm / scale
    } else {
        residual_norm
    };
    Ok(certificate)
}
/// Infinity norm (largest absolute entry) of the concatenation of `lhs` and
/// `rhs`; `0.0` when both are empty.
///
/// Returns NaN if any entry is NaN. `f64::max` returns the non-NaN operand, so
/// a plain running maximum would silently drop NaN entries and report a finite
/// norm for a contaminated vector.
pub(crate) fn infinity_norm_pair(lhs: ArrayView1<'_, f64>, rhs: ArrayView1<'_, f64>) -> f64 {
    let mut out = 0.0_f64;
    for &v in lhs.iter().chain(rhs.iter()) {
        if v.is_nan() {
            return f64::NAN;
        }
        out = out.max(v.abs());
    }
    out
}
/// Infinity norm (maximum absolute row sum) of the ridge-damped arrow operator.
///
/// Row sums over the per-row blocks combine `|htt|`, `ridge_t`, and the
/// materialized `|htbeta|` entries. Row sums over the β block combine the
/// accumulated `|htbeta|` column totals, `ridge_beta`, and the absolute
/// entries of the dense effective penalty operator.
///
/// # Errors
///
/// Propagates errors from `sys_htbeta_materialize_row`.
pub(crate) fn arrow_operator_infinity_norm(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let mut worst = 0.0_f64;
    // Column-wise |htbeta| totals; these become part of the β rows' sums.
    let mut beta_column_abs = vec![0.0_f64; sys.k];

    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let htbeta = sys_htbeta_materialize_row(sys, i, row)?;
        for a in 0..di {
            let htt_abs = (0..di).fold(0.0_f64, |acc, b| acc + row.htt[[a, b]].abs());
            let mut row_sum = htt_abs + ridge_t;
            for (beta_col, column_total) in beta_column_abs.iter_mut().enumerate() {
                let magnitude = htbeta[[a, beta_col]].abs();
                row_sum += magnitude;
                *column_total += magnitude;
            }
            worst = worst.max(row_sum);
        }
    }

    let hbb = sys.effective_penalty_op().to_dense();
    for (beta_row, column_total) in beta_column_abs.iter().enumerate() {
        let row_sum = (0..sys.k).fold(*column_total + ridge_beta, |acc, beta_col| {
            acc + hbb[[beta_row, beta_col]].abs()
        });
        worst = worst.max(row_sum);
    }
    Ok(worst)
}
pub(crate) fn solve_arrow_newton_step_artifacts(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
if !sys.cross_row_penalties.is_empty() {
return solve_arrow_newton_step_cross_row(sys, ridge_t, ridge_beta, options);
}
if let Some(chunk_size) = options.streaming_chunk_size {
let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
let (delta_t, delta_beta, schur_factor) = streaming.solve(ridge_t, ridge_beta, options)?;
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors: ArrowFactorSlab::from_blocks(Vec::new()),
schur_factor,
pcg_diagnostics: PcgDiagnostics::default(),
gauge_deflated_directions: 0,
});
}
let backend = CpuBatchedBlockSolver;
let block_factorization = factor_blocks_for_system(sys, ridge_t, options, &backend)?;
let htt_factors = block_factorization.factors;
let gauge_deflated_directions = block_factorization.gauge_deflated_directions;
let rhs_beta = reduced_rhs_beta(sys, &htt_factors, &backend);
let trust_metric_weights = None;
let mut mixed_precision_status = MixedPrecisionStatus::Off;
let (delta_beta, schur_factor, mut pcg_diagnostics) = match options.mode {
ArrowSolverMode::Direct => {
let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::SqrtBA => {
let schur = build_dense_schur_sqrt_ba(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::InexactPCG => {
if options.solve_precision.is_enabled() {
log::info!(
"arrow-Schur mixed precision fallback to f64: InexactPCG does not expose a dense Schur factor for certified f32 refinement"
);
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
if options.trust_region.radius == f64::INFINITY {
if let Some(device_data) = sys.device_sae_pcg.as_ref() {
let max_iterations = options
.pcg
.max_iterations
.min(options.trust_region.max_iterations);
let relative_tolerance = options
.pcg
.relative_tolerance
.max(options.trust_region.steihaug_relative_tolerance);
if let Ok((delta, mut diag)) =
crate::gpu::kernels::arrow_schur::solve_sae_matrix_free_pcg(
sys,
device_data.as_ref(),
ridge_t,
ridge_beta,
&rhs_beta,
max_iterations,
relative_tolerance,
)
{
diag.used_device_arrow = true;
return Ok(ArrowNewtonStepArtifacts {
delta_t: back_substitute_delta_t(
sys,
&htt_factors,
delta.view(),
&backend,
),
delta_beta: delta,
htt_factors,
schur_factor: None,
pcg_diagnostics: diag,
gauge_deflated_directions,
});
}
}
}
let (delta, diag) = steihaug_pcg_auto(
sys,
&htt_factors,
ridge_beta,
&rhs_beta,
&options.pcg,
&options.trust_region,
&backend,
options.gpu_matvec.as_ref(),
trust_metric_weights,
options.schur_pd_floor,
)?;
(delta, None, diag)
}
};
if mixed_precision_status != MixedPrecisionStatus::Off {
pcg_diagnostics.mixed_precision_status = mixed_precision_status;
}
let delta_t = back_substitute_delta_t(sys, &htt_factors, delta_beta.view(), &backend);
Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor,
pcg_diagnostics,
gauge_deflated_directions,
})
}
/// Factored state used to apply an inverse of the arrow system without its
/// cross-row penalty terms: per-row block factors plus a lower Cholesky factor
/// of the dense Schur matrix.
///
/// Constructed by `build` and applied by `apply`; the cross-row CG solver in
/// this file uses it as its preconditioner and afterwards moves
/// `htt_factors` and `schur_factor` out into the returned artifacts.
pub(crate) struct ArrowBlockDiagInverse<'a, B: BatchedBlockSolver> {
/// System whose row dimensions, offsets and row blocks are read by `apply`.
pub(crate) sys: &'a ArrowSchurSystem,
/// Block solver used for every per-row solve.
pub(crate) backend: &'a B,
/// Per-row factors returned by `backend.factor_blocks` in `build`.
pub(crate) htt_factors: ArrowFactorSlab,
/// Result of `cholesky_lower` on the dense Schur matrix assembled in `build`.
pub(crate) schur_factor: Array2<f64>,
}
impl<'a, B: BatchedBlockSolver> ArrowBlockDiagInverse<'a, B> {
    /// Factors the per-row blocks and the dense Schur matrix of `sys`.
    ///
    /// Row blocks are factored by `backend.factor_blocks` with `ridge_t`; the
    /// Schur matrix is assembled by `build_dense_schur_direct` with
    /// `ridge_beta` and then factored by `cholesky_lower`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the block factorization or the Schur
    /// assembly, and returns `ArrowSchurError::SchurFactorFailed` when
    /// `cholesky_lower` rejects the Schur matrix.
    pub(crate) fn build(
        sys: &'a ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
        tolerate_ill_conditioning: bool,
        backend: &'a B,
    ) -> Result<Self, ArrowSchurError>
    where
        B: Sync,
    {
        let htt_factors =
            backend.factor_blocks(&sys.rows, ridge_t, sys.d, tolerate_ill_conditioning)?;
        let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, backend)?;
        let schur_factor =
            cholesky_lower(&schur).map_err(|e| ArrowSchurError::SchurFactorFailed { reason: e })?;
        Ok(Self {
            sys,
            backend,
            htt_factors,
            schur_factor,
        })
    }

    /// Applies the stored factorization to the residual pair `(r_t, r_beta)`
    /// and returns `(x_t, x_beta)`.
    ///
    /// The work is done in three phases:
    /// 1. `rhs_beta = r_beta - Σ_i acc_i`, where `acc_i` is what
    ///    `sys_htbeta_accumulate_transpose` produces from the block solve of
    ///    row `i`'s slice of `r_t`;
    /// 2. `x_beta` is obtained from `cholesky_solve_lower` on `rhs_beta`;
    /// 3. row `i`'s slice of `x_t` is the block solve of `r_t,i - slab_i`,
    ///    where `slab_i` is filled by `sys_htbeta_apply_row` from `x_beta`.
    ///
    /// Rows are processed in rayon chunks of 64 when the system has at least
    /// `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and
    /// `rayon::current_thread_index()` is `None`; otherwise they are processed
    /// serially. Both paths run the same per-row helpers, so the serial path
    /// no longer depends on `sys.d` being at least as large as every row
    /// dimension.
    ///
    /// # Panics
    ///
    /// Panics if `r_t` is shorter than the system's stacked row length, if
    /// `r_beta` does not match the Schur factor, or if `cholesky_solve_lower`
    /// rejects the factor diagonal.
    pub(crate) fn apply(
        &self,
        r_t: ArrayView1<'_, f64>,
        r_beta: ArrayView1<'_, f64>,
    ) -> (Array1<f64>, Array1<f64>)
    where
        B: Sync,
    {
        let sys = self.sys;
        let n = sys.rows.len();
        let k = sys.k;
        let parallel =
            n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();

        // Phase 1: reduce the row residuals onto the shared block.
        let mut rhs_beta = r_beta.to_owned();
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let partials: Vec<Array1<f64>> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let mut acc = Array1::<f64>::zeros(k);
                    for i in idxs {
                        self.accumulate_forward_row(i, r_t, &mut acc);
                    }
                    acc
                })
                .collect();
            for acc in &partials {
                for a in 0..k {
                    rhs_beta[a] -= acc[a];
                }
            }
        } else {
            for i in 0..n {
                // A fresh accumulator per row keeps the subtraction order of
                // the serial path exactly as it was.
                let mut acc = Array1::<f64>::zeros(k);
                self.accumulate_forward_row(i, r_t, &mut acc);
                for a in 0..k {
                    rhs_beta[a] -= acc[a];
                }
            }
        }

        // Phase 2: dense solve on the shared block.
        let x_beta = cholesky_solve_lower(&self.schur_factor, &rhs_beta);

        // Phase 3: back-substitute every row.
        let total_dt = sys.row_offsets[n];
        let mut x_t = Array1::<f64>::zeros(total_dt);
        if parallel {
            use rayon::prelude::*;
            const CHUNK: usize = 64;
            let chunks: Vec<(usize, Vec<f64>)> = (0..n)
                .into_par_iter()
                .chunks(CHUNK)
                .map(|idxs| {
                    let first = idxs[0];
                    let last = idxs[idxs.len() - 1];
                    let seg_start = sys.row_offsets[first];
                    let seg_end = sys.row_offsets[last] + sys.row_dims[last];
                    let mut seg = vec![0.0_f64; seg_end - seg_start];
                    for i in idxs {
                        let di = sys.row_dims[i];
                        let local = sys.row_offsets[i] - seg_start;
                        self.back_substitute_row_into(
                            i,
                            r_t,
                            x_beta.view(),
                            &mut seg[local..local + di],
                        );
                    }
                    (seg_start, seg)
                })
                .collect();
            for (seg_start, seg) in &chunks {
                for (o, v) in seg.iter().enumerate() {
                    x_t[seg_start + o] = *v;
                }
            }
        } else {
            // `x_t` was just allocated by `zeros`, so it is contiguous.
            let x_t_slice = x_t.as_slice_mut().expect("x_t contiguous");
            for i in 0..n {
                let di = sys.row_dims[i];
                let base = sys.row_offsets[i];
                self.back_substitute_row_into(
                    i,
                    r_t,
                    x_beta.view(),
                    &mut x_t_slice[base..base + di],
                );
            }
        }
        (x_t, x_beta)
    }

    /// Block-solves row `i`'s slice of `r_t` and hands the result to
    /// `sys_htbeta_accumulate_transpose`, which updates `acc`.
    fn accumulate_forward_row(&self, i: usize, r_t: ArrayView1<'_, f64>, acc: &mut Array1<f64>)
    where
        B: Sync,
    {
        let sys = self.sys;
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        let r_ti = r_t.slice(ndarray::s![base..base + di]).to_owned();
        let u_i = self
            .backend
            .solve_block_vector(self.htt_factors.factor(i), r_ti.view());
        sys_htbeta_accumulate_transpose(sys, i, &sys.rows[i], u_i.view(), acc);
    }

    /// Writes row `i`'s slice of the solution into `out`, which must hold
    /// exactly `sys.row_dims[i]` entries.
    fn back_substitute_row_into(
        &self,
        i: usize,
        r_t: ArrayView1<'_, f64>,
        x_beta: ArrayView1<'_, f64>,
        out: &mut [f64],
    ) where
        B: Sync,
    {
        let sys = self.sys;
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        // Sized by the row's own dimension, never by `sys.d`.
        let mut slab = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, &sys.rows[i], x_beta, &mut slab);
        let mut rhs_i = Array1::<f64>::zeros(di);
        for c in 0..di {
            rhs_i[c] = r_t[base + c] - slab[c];
        }
        let xi = self
            .backend
            .solve_block_vector(self.htt_factors.factor(i), rhs_i.view());
        for c in 0..di {
            out[c] = xi[c];
        }
    }
}
/// Computes one row's share of the arrow matrix-vector product.
///
/// Writes `(htt + ridge_t·I)·x_t,i` plus the output of `sys_htbeta_apply_row`
/// for `x_beta` into `seg`, at the row's offset relative to `seg_start`, and
/// passes the row's slice of `x_t` to `sys_htbeta_accumulate_transpose`, which
/// updates `y_beta_acc`. Existing contents of the row's `seg` entries are
/// overwritten.
#[inline]
fn cross_row_matvec_row_into(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    i: usize,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    seg_start: usize,
    seg: &mut [f64],
    y_beta_acc: &mut Array1<f64>,
) {
    let row = &sys.rows[i];
    let dim = sys.row_dims[i];
    let start = sys.row_offsets[i];
    let local = start - seg_start;

    // Coupling term coming from the shared block.
    let mut coupling = Array1::<f64>::zeros(dim);
    sys_htbeta_apply_row(sys, i, row, x_beta, &mut coupling);

    for a in 0..dim {
        // Ridge term first, then the dense row of `htt` in ascending column
        // order, then the coupling term.
        let diagonal_part = (0..dim).fold(ridge_t * x_t[start + a], |sum, b| {
            sum + row.htt[[a, b]] * x_t[start + b]
        });
        seg[local + a] = diagonal_part + coupling[a];
    }

    let x_row = x_t.slice(ndarray::s![start..start + dim]).to_owned();
    sys_htbeta_accumulate_transpose(sys, i, row, x_row.view(), y_beta_acc);
}
/// Full-system matrix-vector product used by the cross-row CG solver.
///
/// Returns `(y_t, y_beta)` built from, in order: the per-row contributions of
/// `cross_row_matvec_row_into`, `sys.penalty_matvec_add` on the shared block,
/// the `ridge_beta · x_beta` term, and `sys.apply_cross_row_penalty_hessian`
/// on the row block.
///
/// Rows run in rayon chunks of 64 when there are at least
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` of them and
/// `rayon::current_thread_index()` is `None`; otherwise they run serially.
///
/// # Panics
///
/// Panics if `x_beta` is not contiguous.
pub(crate) fn arrow_cross_row_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let shared_dim = sys.k;
    let mut y_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    let mut y_beta = Array1::<f64>::zeros(shared_dim);

    let serial =
        row_count < SCHUR_MATVEC_PARALLEL_ROW_MIN || rayon::current_thread_index().is_some();
    if serial {
        let out = y_t.as_slice_mut().expect("y_t contiguous");
        for row_idx in 0..row_count {
            cross_row_matvec_row_into(sys, ridge_t, row_idx, x_t, x_beta, 0, out, &mut y_beta);
        }
    } else {
        use rayon::prelude::*;
        const CHUNK: usize = 64;
        let pieces: Vec<(usize, Vec<f64>, Array1<f64>)> = (0..row_count)
            .into_par_iter()
            .chunks(CHUNK)
            .map(|chunk_rows| {
                let head = chunk_rows[0];
                let tail = chunk_rows[chunk_rows.len() - 1];
                let seg_start = sys.row_offsets[head];
                let seg_end = sys.row_offsets[tail] + sys.row_dims[tail];
                let mut seg = vec![0.0_f64; seg_end - seg_start];
                let mut beta_partial = Array1::<f64>::zeros(shared_dim);
                for row_idx in chunk_rows {
                    cross_row_matvec_row_into(
                        sys,
                        ridge_t,
                        row_idx,
                        x_t,
                        x_beta,
                        seg_start,
                        &mut seg,
                        &mut beta_partial,
                    );
                }
                (seg_start, seg, beta_partial)
            })
            .collect();
        // Stitch the chunk results together in chunk order.
        for (seg_start, seg, beta_partial) in &pieces {
            for (offset, value) in seg.iter().enumerate() {
                y_t[*seg_start + offset] = *value;
            }
            for (total, part) in y_beta.iter_mut().zip(beta_partial.iter()) {
                *total += *part;
            }
        }
    }

    sys.penalty_matvec_add(
        x_beta.as_slice().expect("x_beta contiguous"),
        y_beta.as_slice_mut().expect("y_beta contiguous"),
    );
    for (a, entry) in y_beta.iter_mut().enumerate() {
        *entry += ridge_beta * x_beta[a];
    }
    sys.apply_cross_row_penalty_hessian(x_t, &mut y_t);
    (y_t, y_beta)
}
/// Solves the full arrow system, including cross-row penalty terms, with
/// preconditioned conjugate gradients.
///
/// The right-hand side is the negated gradient (`-gt` for every row, `-gb` for
/// the shared block), the operator is `arrow_cross_row_matvec`, and the
/// preconditioner is an `ArrowBlockDiagInverse` built with the same ridges.
/// Iteration stops once the residual 2-norm is at most
/// `max(1e-12, 1e-13 · ‖b‖)`, with a budget of
/// `4 · max(total_dt + k, 64)` iterations.
///
/// Only `options.tolerate_ill_conditioning` is read from `options`.
///
/// # Errors
///
/// * Propagates any error from `ArrowBlockDiagInverse::build`.
/// * `ArrowSchurError::PcgFailed` when `pᵀAp` is non-finite or not strictly
///   positive, or when the iteration budget is exhausted without convergence.
///
/// # Diagnostics
///
/// `precond_apply_calls` is the number of `precond.apply` calls actually made.
/// The loop leaves on convergence before re-applying the preconditioner, so
/// this is one initial call plus one per non-final iteration.
pub(crate) fn solve_arrow_newton_step_cross_row(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
    const CROSS_ROW_CG_ABS_TOL: f64 = 1e-12;
    const CROSS_ROW_CG_REL_TOL: f64 = 1e-13;
    const CROSS_ROW_CG_MIN_ITER_BUDGET: usize = 64;
    const CROSS_ROW_CG_ITER_MULTIPLE: usize = 4;

    let backend = CpuBatchedBlockSolver;
    let precond = ArrowBlockDiagInverse::build(
        sys,
        ridge_t,
        ridge_beta,
        options.tolerate_ill_conditioning,
        &backend,
    )?;
    let n = sys.rows.len();
    let k = sys.k;
    let total_dt = sys.row_offsets[n];

    // Right-hand side: the negated gradient, stacked rows first.
    let mut b_t = Array1::<f64>::zeros(total_dt);
    for i in 0..n {
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        for c in 0..di {
            b_t[base + c] = -sys.rows[i].gt[c];
        }
    }
    let mut b_beta = Array1::<f64>::zeros(k);
    for a in 0..k {
        b_beta[a] = -sys.gb[a];
    }

    // Zero initial guess, so the initial residual equals the right-hand side.
    let mut x_t = Array1::<f64>::zeros(total_dt);
    let mut x_beta = Array1::<f64>::zeros(k);
    let mut r_t = b_t.clone();
    let mut r_beta = b_beta.clone();
    let (mut z_t, mut z_beta) = precond.apply(r_t.view(), r_beta.view());
    // Counted explicitly so the reported figure matches the calls made.
    let mut precond_apply_calls = 1usize;
    let mut p_t = z_t.clone();
    let mut p_beta = z_beta.clone();
    let mut rz = dot2(&r_t, &r_beta, &z_t, &z_beta);
    let b_norm = (dot2(&b_t, &b_beta, &b_t, &b_beta)).sqrt();

    let tol = CROSS_ROW_CG_ABS_TOL.max(CROSS_ROW_CG_REL_TOL * b_norm);
    let max_iter = (total_dt + k).max(CROSS_ROW_CG_MIN_ITER_BUDGET) * CROSS_ROW_CG_ITER_MULTIPLE;
    let mut iters = 0usize;
    let mut converged = b_norm == 0.0;
    while iters < max_iter && !converged {
        let (ap_t, ap_beta) =
            arrow_cross_row_matvec(sys, ridge_t, ridge_beta, p_t.view(), p_beta.view());
        let pap = dot2(&p_t, &p_beta, &ap_t, &ap_beta);
        if !(pap.is_finite() && pap > 0.0) {
            return Err(ArrowSchurError::PcgFailed {
                reason: format!(
                    "cross-row full-system CG hit non-positive curvature pᵀAp={pap:e}; \
                     the cross-row penalty Hessian or arrow block is not PD at this iterate"
                ),
            });
        }
        let alpha = rz / pap;
        for i in 0..total_dt {
            x_t[i] += alpha * p_t[i];
            r_t[i] -= alpha * ap_t[i];
        }
        for a in 0..k {
            x_beta[a] += alpha * p_beta[a];
            r_beta[a] -= alpha * ap_beta[a];
        }
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        iters += 1;
        if r_norm <= tol {
            // Leave before re-applying the preconditioner.
            converged = true;
            break;
        }
        let (nz_t, nz_beta) = precond.apply(r_t.view(), r_beta.view());
        precond_apply_calls += 1;
        z_t = nz_t;
        z_beta = nz_beta;
        let rz_new = dot2(&r_t, &r_beta, &z_t, &z_beta);
        let beta_cg = rz_new / rz;
        for i in 0..total_dt {
            p_t[i] = z_t[i] + beta_cg * p_t[i];
        }
        for a in 0..k {
            p_beta[a] = z_beta[a] + beta_cg * p_beta[a];
        }
        rz = rz_new;
    }
    if !converged {
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        return Err(ArrowSchurError::PcgFailed {
            reason: format!(
                "cross-row full-system CG did not converge in {iters} iters \
                 (‖r‖={r_norm:e}, tol={tol:e})"
            ),
        });
    }
    let final_residual = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
    let diag = PcgDiagnostics {
        iterations: iters,
        matvec_calls: iters,
        precond_apply_calls,
        ridge_escalations: 0,
        final_relative_residual: if b_norm > 0.0 {
            final_residual / b_norm
        } else {
            0.0
        },
        stopping_reason: PcgStopReason::Converged,
        mixed_precision_status: MixedPrecisionStatus::Off,
        used_device_arrow: false,
        injected_host_procedural_matvec: false,
    };
    Ok(ArrowNewtonStepArtifacts {
        delta_t: x_t,
        delta_beta: x_beta,
        htt_factors: precond.htt_factors,
        schur_factor: Some(precond.schur_factor),
        pcg_diagnostics: diag,
        gauge_deflated_directions: 0,
    })
}
/// Inner product of the stacked vectors `(a_t, a_beta)` and `(b_t, b_beta)`.
///
/// Products are summed strictly left to right, the `t` part first and the
/// `beta` part second, into a single running total.
///
/// # Panics
///
/// Panics if `b_t` is shorter than `a_t` or `b_beta` is shorter than `a_beta`.
pub(crate) fn dot2(
    a_t: &Array1<f64>,
    a_beta: &Array1<f64>,
    b_t: &Array1<f64>,
    b_beta: &Array1<f64>,
) -> f64 {
    let t_products = (0..a_t.len()).map(|idx| a_t[idx] * b_t[idx]);
    let beta_products = (0..a_beta.len()).map(|idx| a_beta[idx] * b_beta[idx]);
    t_products
        .chain(beta_products)
        .fold(0.0_f64, |total, product| total + product)
}
/// Solves `L·Lᵀ·x = b` given the lower-triangular factor `l`.
///
/// Forward substitution with `l` is followed by backward substitution with
/// its transpose; both sweeps share one buffer, because the backward sweep
/// only reads entries it has already finalized.
///
/// # Panics
///
/// Panics if any diagonal entry of `l` is non-finite or smaller in magnitude
/// than `f64::MIN_POSITIVE`, or if `b` has fewer than `l.nrows()` entries.
pub(crate) fn cholesky_solve_lower(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let dim = l.nrows();
    let has_bad_pivot = (0..dim).any(|p| {
        let pivot = l[[p, p]];
        !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE
    });
    assert!(
        !has_bad_pivot,
        "cholesky_solve_lower: factor diagonal must be finite and non-subnormal"
    );

    let mut solution = Array1::<f64>::zeros(dim);
    // Forward sweep: L·y = b.
    for row in 0..dim {
        let remainder = (0..row).fold(b[row], |rem, col| rem - l[[row, col]] * solution[col]);
        solution[row] = remainder / l[[row, row]];
    }
    // Backward sweep: Lᵀ·x = y, overwriting y in place from the last row up.
    for row in (0..dim).rev() {
        let remainder =
            ((row + 1)..dim).fold(solution[row], |rem, col| rem - l[[col, row]] * solution[col]);
        solution[row] = remainder / l[[row, row]];
    }
    solution
}
/// Builds the right-hand side of the reduced (shared-block) system.
///
/// For every row, the block solve of the row gradient `gt` is handed to
/// `sys_htbeta_accumulate_transpose`; the accumulated vector minus `sys.gb`
/// is returned.
///
/// Rows run in rayon chunks of 64 when there are at least
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` of them and
/// `rayon::current_thread_index()` is `None`; otherwise they accumulate
/// serially into the result.
pub(crate) fn reduced_rhs_beta<B: BatchedBlockSolver + Sync>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    backend: &B,
) -> Array1<f64> {
    let row_count = sys.rows.len();
    let shared_dim = sys.k;

    // Adds one row's contribution into `sink`.
    let fold_row = |row_idx: usize, sink: &mut Array1<f64>| {
        let row = &sys.rows[row_idx];
        let solved = backend.solve_block_vector(htt_factors.factor(row_idx), row.gt.view());
        sys_htbeta_accumulate_transpose(sys, row_idx, row, solved.view(), sink);
    };

    let mut rhs_beta = Array1::<f64>::zeros(shared_dim);
    let serial =
        row_count < SCHUR_MATVEC_PARALLEL_ROW_MIN || rayon::current_thread_index().is_some();
    if serial {
        for row_idx in 0..row_count {
            fold_row(row_idx, &mut rhs_beta);
        }
    } else {
        use rayon::prelude::*;
        const CHUNK: usize = 64;
        let partials: Vec<Array1<f64>> = (0..row_count)
            .into_par_iter()
            .chunks(CHUNK)
            .map(|chunk_rows| {
                let mut local = Array1::<f64>::zeros(shared_dim);
                for row_idx in chunk_rows {
                    fold_row(row_idx, &mut local);
                }
                local
            })
            .collect();
        // Combine the chunk partials in chunk order.
        for partial in &partials {
            for (total, part) in rhs_beta.iter_mut().zip(partial.iter()) {
                *total += *part;
            }
        }
    }

    for (j, entry) in rhs_beta.iter_mut().enumerate() {
        *entry -= sys.gb[j];
    }
    rhs_beta
}
/// Selects how `row_schur_contribution_factors` forms the pair of matrices
/// that a row contributes to the dense Schur matrix.
#[derive(Clone, Copy, Debug)]
pub(crate) enum SchurReductionKind {
/// Pair the materialized cross block with `solve_block_matrix` applied to it.
Direct,
/// Use the output of `sqrt_solve_block_matrix` for both members of the pair.
SqrtBa,
}
/// Produces the `(left, right)` matrix pair that one row contributes to the
/// Schur reduction.
///
/// The row's cross block is materialized by `sys_htbeta_materialize_row`.
/// For `Direct` the pair is the cross block and `solve_block_matrix` of it;
/// for `SqrtBa` both members are `sqrt_solve_block_matrix` of it.
///
/// # Errors
///
/// Propagates any error from `sys_htbeta_materialize_row`.
#[inline]
pub(crate) fn row_schur_contribution_factors<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
) -> Result<(Array2<f64>, Array2<f64>), ArrowSchurError> {
    let cross_block = sys_htbeta_materialize_row(sys, row_idx, row)?;
    let factor_pair = match kind {
        SchurReductionKind::SqrtBa => {
            let right = backend.sqrt_solve_block_matrix(htt_factor, cross_block.view());
            (right.clone(), right)
        }
        SchurReductionKind::Direct => {
            let right = backend.solve_block_matrix(htt_factor, cross_block.view());
            (cross_block, right)
        }
    };
    Ok(factor_pair)
}
/// Updates `schur` with one row's contribution.
///
/// Obtains the row's factor pair from `row_schur_contribution_factors` and
/// passes it to `backend.block_gemm_subtract`.
///
/// # Errors
///
/// Propagates any error from `row_schur_contribution_factors`; `schur` is not
/// touched in that case.
#[inline]
pub(crate) fn subtract_row_schur_contribution<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
    schur: &mut Array2<f64>,
) -> Result<(), ArrowSchurError> {
    let factor_pair = row_schur_contribution_factors(sys, row_idx, row, htt_factor, backend, kind)?;
    backend.block_gemm_subtract(schur, &factor_pair.0, &factor_pair.1);
    Ok(())
}