use super::*;
/// Solver strategy selector, carried in `ArrowSolveOptions::mode`.
///
/// `ArrowSolverMode::automatic` and
/// `ArrowSolverMode::automatic_for_single_precision` choose a variant by
/// comparing a size `k` against `DIRECT_SOLVE_MAX_K`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
/// Chosen by `automatic` when `k <= DIRECT_SOLVE_MAX_K`.
/// NOTE(review): presumably a direct factorization-based solve — confirm.
Direct,
/// Chosen by `automatic_for_single_precision` when `k <= DIRECT_SOLVE_MAX_K`.
/// NOTE(review): presumably a square-root formulation — confirm.
SqrtBA,
/// Chosen by both automatic selectors when `k > DIRECT_SOLVE_MAX_K`.
/// NOTE(review): presumably an inexact preconditioned-CG solve — confirm.
InexactPCG,
}
impl ArrowSolverMode {
    /// Selects a mode from the size `k`.
    ///
    /// Returns [`Self::InexactPCG`] once `k` exceeds `DIRECT_SOLVE_MAX_K`,
    /// and [`Self::Direct`] otherwise (the threshold itself is inclusive).
    pub const fn automatic(k: usize) -> Self {
        let exceeds_direct_limit = k > DIRECT_SOLVE_MAX_K;
        if exceeds_direct_limit {
            Self::InexactPCG
        } else {
            Self::Direct
        }
    }

    /// Same threshold as [`Self::automatic`], but sizes at or below
    /// `DIRECT_SOLVE_MAX_K` map to [`Self::SqrtBA`] instead of
    /// [`Self::Direct`].
    pub const fn automatic_for_single_precision(k: usize) -> Self {
        let exceeds_direct_limit = k > DIRECT_SOLVE_MAX_K;
        if exceeds_direct_limit {
            Self::InexactPCG
        } else {
            Self::SqrtBA
        }
    }
}
/// Stop reason recorded in `PcgDiagnostics::stopping_reason`.
///
/// NOTE(review): the code that sets these values is outside this section, so
/// the per-variant notes below are inferred from the names — confirm.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
/// Default variant. Presumably the convergence tolerance was met.
#[default]
Converged,
/// Presumably the iteration limit was reached first.
MaxIter,
/// Presumably the iterate reached the trust-region boundary.
TrustRegion,
/// Presumably a direction of non-positive curvature was encountered.
Indefinite,
/// Presumably the iteration stopped making progress.
Stagnation,
}
/// Counters and status flags reported for a PCG solve.
///
/// The derived `Default` yields zero counters, a `0.0` residual,
/// `PcgStopReason::Converged`, `MixedPrecisionStatus::Off` and `false` flags,
/// so a default value on its own does not show that a solve actually ran.
///
/// NOTE(review): the fields are populated outside this section; the per-field
/// notes are inferred from names and types — confirm against the solver.
#[derive(Debug, Default, Clone, Copy)]
pub struct PcgDiagnostics {
/// Number of iterations performed.
pub iterations: usize,
/// Number of matrix-vector products evaluated.
pub matvec_calls: usize,
/// Number of preconditioner applications.
pub precond_apply_calls: usize,
/// Number of times a ridge was escalated (presumably to recover from a failed solve).
pub ridge_escalations: usize,
/// Relative residual at exit.
pub final_relative_residual: f64,
/// Why the iteration stopped.
pub stopping_reason: PcgStopReason,
/// Outcome of the mixed-precision path, if any.
pub mixed_precision_status: MixedPrecisionStatus,
/// Presumably whether a device-side arrow path was used — confirm.
pub used_device_arrow: bool,
/// Presumably whether a host-side procedural matvec was injected — confirm.
pub injected_host_procedural_matvec: bool,
}
/// Mixed-precision outcome reported in `PcgDiagnostics::mixed_precision_status`.
///
/// NOTE(review): variant meanings are inferred from names; the code that sets
/// them is outside this section.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
/// Default variant. Presumably mixed precision was not used.
#[default]
Off,
/// Carries the number of refinement steps taken. Presumably the
/// mixed-precision result passed its certificate — confirm.
Certified { refinement_steps: usize },
/// Presumably the solve fell back to full `f64` arithmetic — confirm.
F64Fallback,
}
/// PCG settings, carried in `ArrowSolveOptions::pcg`.
///
/// `Default` fills both fields from `DEFAULT_PCG_MAX_ITERATIONS` and
/// `DEFAULT_PCG_RELATIVE_TOLERANCE`.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
/// Iteration cap (presumably an upper bound on PCG iterations — confirm).
pub max_iterations: usize,
/// Relative tolerance (presumably on the residual norm — confirm).
pub relative_tolerance: f64,
}
impl Default for ArrowPcgOptions {
    /// Builds the options from the crate-level PCG default constants.
    fn default() -> Self {
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        let relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        Self {
            max_iterations,
            relative_tolerance,
        }
    }
}
/// Trust-region settings, carried in `ArrowSolveOptions::trust_region`.
///
/// `Default` takes the radius from `DEFAULT_TRUST_REGION_RADIUS` and reuses
/// the PCG default constants for the tolerance and the iteration cap.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
/// Trust-region radius (NOTE(review): norm and units are not visible here).
pub radius: f64,
/// Relative tolerance (presumably for a Steihaug-style truncated CG — confirm).
pub steihaug_relative_tolerance: f64,
/// Iteration cap for the trust-region solve.
pub max_iterations: usize,
}
impl Default for ArrowTrustRegionOptions {
    /// Builds the options from crate-level constants.
    ///
    /// The tolerance and iteration cap deliberately reuse the PCG defaults;
    /// only the radius has a trust-region-specific constant.
    fn default() -> Self {
        let radius = DEFAULT_TRUST_REGION_RADIUS;
        let steihaug_relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        Self {
            radius,
            steihaug_relative_tolerance,
            max_iterations,
        }
    }
}
/// Precision policy for a solve, carried in `ArrowSolveOptions::solve_precision`.
///
/// The default is [`ArrowSolvePrecisionPolicy::F64Only`]. It is derived via
/// `#[default]`, matching how the other enums in this file declare their
/// defaults, rather than spelled out in a hand-written `impl Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum ArrowSolvePrecisionPolicy {
    /// Default policy; `is_enabled` reports `false` for it.
    #[default]
    F64Only,
    /// Mixed-precision policy; `is_enabled` reports `true` for it.
    ///
    /// NOTE(review): the field notes are inferred from names — confirm
    /// against the code that consumes this policy.
    CertifiedMixed {
        /// Upper bound on refinement steps.
        max_refinement_steps: usize,
        /// Relative residual tolerance (presumably for the certificate).
        residual_relative_tolerance: f64,
        /// Margin (presumably on condition number times unit roundoff).
        kappa_unit_roundoff_margin: f64,
    },
}
impl ArrowSolvePrecisionPolicy {
    /// Returns the `CertifiedMixed` policy populated from the crate's
    /// `DEFAULT_MIXED_PRECISION_*` constants.
    pub fn certified_mixed() -> Self {
        let max_refinement_steps = DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS;
        let residual_relative_tolerance = DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE;
        let kappa_unit_roundoff_margin = DEFAULT_MIXED_PRECISION_KAPPA_MARGIN;
        Self::CertifiedMixed {
            max_refinement_steps,
            residual_relative_tolerance,
            kappa_unit_roundoff_margin,
        }
    }

    /// Reports whether this policy is the `CertifiedMixed` variant.
    pub(crate) fn is_enabled(self) -> bool {
        // Exhaustive on purpose: a new variant must decide which side it is on.
        match self {
            Self::CertifiedMixed { .. } => true,
            Self::F64Only => false,
        }
    }
}
/// Bundle of settings for a solve.
///
/// Only `Clone` is derived; `Debug` is implemented by hand so that
/// `gpu_matvec` is printed as a presence flag rather than as the value.
///
/// NOTE(review): the fields are consumed outside this section; notes marked
/// "presumably" are inferred from names — confirm.
#[derive(Clone)]
pub struct ArrowSolveOptions {
/// Which solver strategy to use.
pub mode: ArrowSolverMode,
/// PCG settings.
pub pcg: ArrowPcgOptions,
/// Trust-region settings.
pub trust_region: ArrowTrustRegionOptions,
/// Optional chunk size; `with_streaming_chunk_size` stores `None` for `Some(0)`.
pub streaming_chunk_size: Option<usize>,
/// Presumably enables a Riemannian trust-region variant; the constructors set `false`.
pub riemannian_trust_region: bool,
/// Optional GPU matvec handle; the constructors set `None`.
pub gpu_matvec: Option<GpuSchurMatvec>,
/// Forwarded-style flag set by `with_ill_conditioning_tolerated`; the constructors set `false`.
pub tolerate_ill_conditioning: bool,
/// Precision policy; the constructors set `ArrowSolvePrecisionPolicy::F64Only`.
pub solve_precision: ArrowSolvePrecisionPolicy,
/// Presumably an optional positive-definiteness floor for the Schur block; the constructors set `None`.
pub schur_pd_floor: Option<f64>,
}
impl std::fmt::Debug for ArrowSolveOptions {
    /// Writes every field by name; `gpu_matvec` is rendered only as the
    /// boolean `is_some()`, never as the wrapped value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a newly added field that is not listed
        // here becomes a compile error instead of silently missing output.
        let Self {
            mode,
            pcg,
            trust_region,
            streaming_chunk_size,
            riemannian_trust_region,
            gpu_matvec,
            tolerate_ill_conditioning,
            solve_precision,
            schur_pd_floor,
        } = self;
        let has_gpu_matvec = gpu_matvec.is_some();

        let mut out = f.debug_struct("ArrowSolveOptions");
        out.field("mode", mode);
        out.field("pcg", pcg);
        out.field("trust_region", trust_region);
        out.field("streaming_chunk_size", streaming_chunk_size);
        out.field("riemannian_trust_region", riemannian_trust_region);
        out.field("gpu_matvec", &has_gpu_matvec);
        out.field("tolerate_ill_conditioning", tolerate_ill_conditioning);
        out.field("solve_precision", solve_precision);
        out.field("schur_pd_floor", schur_pd_floor);
        out.finish()
    }
}
/// Settings for a proximal correction; `Default` fills every field from a
/// crate-level `DEFAULT_*` constant.
///
/// NOTE(review): the fields are consumed outside this section; the notes below
/// are inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
/// Ridge value used for the first attempt.
pub initial_ridge: f64,
/// Presumably the multiplicative factor applied to the ridge between attempts.
pub ridge_growth: f64,
/// Upper bound on the number of attempts.
pub max_attempts: usize,
/// Presumably the Armijo sufficient-decrease constant.
pub armijo_c1: f64,
/// Gradient tolerance (NOTE(review): norm not visible here).
pub gradient_tolerance: f64,
/// Relative tolerance on the objective used for a convergence test.
pub convergence_objective_rel_tol: f64,
}
impl Default for ArrowProximalCorrectionOptions {
    /// Builds the options from the crate-level default constants, one
    /// constant per field.
    fn default() -> Self {
        let initial_ridge = DEFAULT_PROXIMAL_INITIAL_RIDGE;
        let ridge_growth = DEFAULT_PROXIMAL_RIDGE_GROWTH;
        let max_attempts = DEFAULT_PROXIMAL_MAX_ATTEMPTS;
        let armijo_c1 = DEFAULT_ARMIJO_C1;
        let gradient_tolerance = DEFAULT_GRADIENT_TOLERANCE;
        let convergence_objective_rel_tol = DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL;
        Self {
            initial_ridge,
            ridge_growth,
            max_attempts,
            armijo_c1,
            gradient_tolerance,
            convergence_objective_rel_tol,
        }
    }
}
/// Record of an accepted proximal step.
///
/// NOTE(review): this type is produced outside this section; every field note
/// below is inferred from the field name and type — confirm with the producer.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
/// Step for the `t` variables.
pub delta_t: Array1<f64>,
/// Step for the `beta` variables.
pub delta_beta: Array1<f64>,
/// Ridge associated with the `t` block.
pub ridge_t: f64,
/// Ridge associated with the `beta` block.
pub ridge_beta: f64,
/// Proximal ridge associated with the accepted step.
pub proximal_ridge: f64,
/// Objective value (presumably at the point the step was taken from).
pub objective_value: f64,
/// Objective value (presumably at the trial point).
pub trial_objective_value: f64,
/// Inner product of the gradient with the step.
pub gradient_dot_step: f64,
/// Number of attempts made.
pub attempts: usize,
}
impl ArrowSolveOptions {
pub fn automatic(k: usize) -> Self {
Self {
mode: ArrowSolverMode::automatic(k),
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
solve_precision: ArrowSolvePrecisionPolicy::F64Only,
schur_pd_floor: None,
}
}
pub fn direct() -> Self {
Self {
mode: ArrowSolverMode::Direct,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
solve_precision: ArrowSolvePrecisionPolicy::F64Only,
schur_pd_floor: None,
}
}
pub fn sqrt_ba() -> Self {
Self {
mode: ArrowSolverMode::SqrtBA,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
solve_precision: ArrowSolvePrecisionPolicy::F64Only,
schur_pd_floor: None,
}
}
pub fn inexact_pcg() -> Self {
Self {
mode: ArrowSolverMode::InexactPCG,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
solve_precision: ArrowSolvePrecisionPolicy::F64Only,
schur_pd_floor: None,
}
}
pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
self
}
pub fn with_ill_conditioning_tolerated(mut self) -> Self {
self.tolerate_ill_conditioning = true;
self
}
pub fn with_solve_precision_policy(mut self, policy: ArrowSolvePrecisionPolicy) -> Self {
self.solve_precision = policy;
self
}
#[must_use]
pub fn with_streaming_solve_precision_default(&self) -> Self {
let mut out = self.clone();
if matches!(out.solve_precision, ArrowSolvePrecisionPolicy::F64Only) {
out.solve_precision = ArrowSolvePrecisionPolicy::certified_mixed();
}
out
}
}
/// Backend interface for the per-row block operations of the solver.
///
/// `CpuBatchedBlockSolver` is the implementation in this file. Notes on what
/// each method computes describe that implementation; other backends are
/// presumably expected to match it — confirm.
pub trait BatchedBlockSolver {
/// Factors the blocks of `rows` and returns them as one `ArrowFactorSlab`,
/// or an `ArrowSchurError` on failure.
///
/// NOTE(review): `ridge_t` is presumably a diagonal ridge applied to each
/// block and `d` the block dimension — confirm against `factor_one_row`.
fn factor_blocks(
&self,
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
tolerate_ill_conditioning: bool,
) -> Result<ArrowFactorSlab, ArrowSchurError>;
/// Solves with one block factor against a single right-hand-side vector
/// (the CPU backend uses its Cholesky solve helpers).
fn solve_block_vector(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView1<'_, f64>,
) -> Array1<f64>;
/// Solves with one block factor against a matrix right-hand side
/// (the CPU backend calls `cholesky_solve_matrix`).
fn solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// Applies only a triangular solve with the factor to a matrix right-hand
/// side (the CPU backend calls `forward_substitution_lower_matrix`).
fn sqrt_solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// Updates `schur` in place by subtracting a product of `left` and `right`
/// (the CPU backend computes `schur[a, b] -= left[c, a] * right[c, b]`
/// summed over `c`).
fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
/// Per-row lists of direction vectors, held behind an `Arc` so that cloning
/// this value shares the data instead of copying it.
///
/// NOTE(review): presumably these are gauge directions to deflate per row —
/// the consumer is outside this section.
#[derive(Debug, Clone)]
pub struct ArrowRowGaugeDeflation {
/// Outer index is the row; each entry is that row's list of directions.
pub directions: Arc<[Vec<Array1<f64>>]>,
}
impl ArrowRowGaugeDeflation {
    /// Moves the per-row direction lists into a shared `Arc` slice.
    pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
        let directions: Arc<[Vec<Array1<f64>>]> = Arc::from(directions);
        Self { directions }
    }

    /// Returns the directions stored for `row`, or an empty slice when `row`
    /// is past the end of the stored list.
    pub(crate) fn row(&self, row: usize) -> &[Array1<f64>] {
        match self.directions.get(row) {
            Some(per_row) => per_row.as_slice(),
            None => &[],
        }
    }
}
/// Stateless (unit-struct) CPU implementation of `BatchedBlockSolver`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
impl BatchedBlockSolver for CpuBatchedBlockSolver {
    /// Factors every block in `rows`.
    ///
    /// The batched path (`try_factor_blocks_batched`) is tried first; when it
    /// returns `None`, each row is factored on its own with `factor_one_row`
    /// and the first error aborts the whole call.
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
        tolerate_ill_conditioning: bool,
    ) -> Result<ArrowFactorSlab, ArrowSchurError> {
        if let Some(batched) =
            try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
        {
            return Ok(batched);
        }
        let mut out = Vec::with_capacity(rows.len());
        for (row_idx, row) in rows.iter().enumerate() {
            out.push(factor_one_row(
                row,
                ridge_t,
                d,
                row_idx,
                tolerate_ill_conditioning,
            )?);
        }
        Ok(ArrowFactorSlab::from_blocks(out))
    }

    /// Solves one block against a vector right-hand side.
    ///
    /// Square factors of size 1 to 4 whose size matches `rhs` go to the
    /// const-generic fixed-size kernel; every other shape goes to the
    /// general `cholesky_solve_vector`.
    fn solve_block_vector(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        match (factor.nrows(), factor.ncols(), rhs.len()) {
            (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
            (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
            (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
            (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
            _ => cholesky_solve_vector(factor, rhs),
        }
    }

    /// Solves one block against a matrix right-hand side via
    /// `cholesky_solve_matrix`.
    fn solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        cholesky_solve_matrix(factor, rhs)
    }

    /// Applies `forward_substitution_lower_matrix` with `factor` to `rhs`.
    fn sqrt_solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        forward_substitution_lower_matrix(factor, rhs)
    }

    /// In-place update `schur[a, b] -= sum_c left[c, a] * right[c, b]`,
    /// that is `schur -= leftᵀ · right`.
    ///
    /// # Panics
    ///
    /// Panics unless `schur` is `k × k` and `left` and `right` are both
    /// `d × k`. The row-count check on `right` guards against a mismatched
    /// operand being silently truncated (too many rows) or failing later
    /// with an opaque out-of-bounds index (too few rows).
    fn block_gemm_subtract(
        &self,
        schur: &mut Array2<f64>,
        left: &Array2<f64>,
        right: &Array2<f64>,
    ) {
        let k = schur.nrows();
        let d = left.nrows();
        assert_eq!(left.ncols(), k);
        assert_eq!(right.ncols(), k);
        assert_eq!(schur.ncols(), k);
        assert_eq!(
            right.nrows(),
            d,
            "block_gemm_subtract: `left` and `right` must have the same number of rows"
        );
        for c in 0..d {
            for a in 0..k {
                let lca = left[[c, a]];
                // Exact zeros contribute nothing, so the whole inner row is
                // skipped; this also means a zero in `left` never multiplies
                // a non-finite entry of `right`.
                if lca == 0.0 {
                    continue;
                }
                for b in 0..k {
                    schur[[a, b]] -= lca * right[[c, b]];
                }
            }
        }
    }
}