use super::*;
/// One row's slice of the arrow-structured Newton system.
///
/// As built by `ArrowRowBlock::new_with_htbeta_cols(d, cols)`: `htt` is `d × d`,
/// `htbeta` is `d × cols`, and `gt` has length `d`.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
/// Latent-latent Hessian block of this row.
pub htt: Array2<f64>,
/// Latent-beta coupling block (one row per latent axis). The streaming path stores a
/// `0 × 0` placeholder here when a matrix-free coupling operator is installed.
pub htbeta: Array2<f64>,
/// Latent gradient of this row.
pub gt: Array1<f64>,
}
impl ArrowRowBlock {
    /// Zeroed row block with `d` latent axes whose coupling block has `k` columns.
    pub fn new(d: usize, k: usize) -> Self {
        let coupling_cols = k;
        Self::new_with_htbeta_cols(d, coupling_cols)
    }

    /// Zeroed row block: `htt` is `d × d`, `htbeta` is `d × htbeta_cols`, `gt` has
    /// length `d`.
    pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
        let gt = Array1::<f64>::zeros(d);
        let htbeta = Array2::<f64>::zeros((d, htbeta_cols));
        let htt = Array2::<f64>::zeros((d, d));
        Self { htt, htbeta, gt }
    }
}
/// Assembled arrow-structured Newton system: independent per-row latent blocks (`rows`)
/// coupled to one shared beta block (`hbb`, `gb`, or an operator standing in for `hbb`).
pub struct ArrowSchurSystem {
/// Per-row latent Hessian / gradient / coupling blocks, one per row.
pub rows: Vec<ArrowRowBlock>,
/// Dense shared beta Hessian: `k × k` when dense, `0 × 0` for empty / matrix-free builds.
pub hbb: Array2<f64>,
/// Optional matrix-free product for the beta block (see `set_shared_beta_operator`).
pub hbb_matvec: Option<SharedBetaMatvec>,
/// Optional per-row matrix-free coupling product (beta-space input, row latent output).
pub htbeta_matvec: Option<RowHtbetaMatvec>,
/// Optional per-row matrix-free transpose coupling product (row latent input, beta output).
pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
/// Set by `activate_dense_htbeta_supplement`; its consumers are outside this view.
pub htbeta_dense_supplement: bool,
/// Diagonal of the beta block when stored separately (matrix-free constructions).
pub hbb_diag: Option<Array1<f64>>,
/// Shared beta gradient (length `k`).
pub gb: Array1<f64>,
/// Latent dimension per row; for ragged systems the maximum row dimension.
pub d: usize,
/// Latent dimension of each row.
pub row_dims: Arc<[usize]>,
/// Prefix sums of `row_dims` (length `n + 1`); row `i` occupies
/// `row_offsets[i]..row_offsets[i + 1]` of a flattened latent vector.
pub row_offsets: Arc<[usize]>,
/// Shared beta dimension.
pub k: usize,
/// Fingerprint of the latent manifold mode (set by `apply_riemannian_latent_geometry`).
pub manifold_mode_fingerprint: u64,
/// Cached row-Hessian fingerprint (see `refresh_row_hessian_fingerprint`).
pub row_hessian_fingerprint: u64,
/// Fingerprint of the analytic penalty contributions to the row Hessians
/// (set by `add_analytic_penalty_contributions`; `0` when there are none).
pub analytic_row_hessian_fingerprint: u64,
/// Beta-block partition used by `penalty_block_add` callers.
pub block_offsets: Arc<[Range<usize>]>,
/// Optional beta penalty operator; when present the `penalty_*` helpers use it in
/// preference to `hbb` / `hbb_diag`.
pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
/// Optional device PCG data (see `set_device_sae_pcg_data`).
pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
/// Latent penalties whose Hessian couples rows; applied matrix-free by
/// `apply_cross_row_penalty_hessian`.
pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
/// Optional per-row gauge deflation (the streaming path applies it when factoring rows).
pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
/// Optional cross-row IBP source; `None` whenever the source has rank 0 or no entries.
pub ibp_cross_row: Option<IbpCrossRowSource>,
}
impl Clone for ArrowSchurSystem {
    /// Field-wise clone. Arrays and vectors are copied; every `Arc`-backed field only has
    /// its reference count bumped. The exhaustive destructuring below makes the compiler
    /// flag this impl whenever a field is added to the struct.
    fn clone(&self) -> Self {
        let Self {
            rows,
            hbb,
            hbb_matvec,
            htbeta_matvec,
            htbeta_transpose_matvec,
            htbeta_dense_supplement,
            hbb_diag,
            gb,
            d,
            row_dims,
            row_offsets,
            k,
            manifold_mode_fingerprint,
            row_hessian_fingerprint,
            analytic_row_hessian_fingerprint,
            block_offsets,
            penalty_op,
            device_sae_pcg,
            cross_row_penalties,
            row_gauge_deflation,
            ibp_cross_row,
        } = self;
        Self {
            rows: rows.clone(),
            hbb: hbb.clone(),
            hbb_matvec: hbb_matvec.clone(),
            htbeta_matvec: htbeta_matvec.clone(),
            htbeta_transpose_matvec: htbeta_transpose_matvec.clone(),
            htbeta_dense_supplement: *htbeta_dense_supplement,
            hbb_diag: hbb_diag.clone(),
            gb: gb.clone(),
            d: *d,
            row_dims: Arc::clone(row_dims),
            row_offsets: Arc::clone(row_offsets),
            k: *k,
            manifold_mode_fingerprint: *manifold_mode_fingerprint,
            row_hessian_fingerprint: *row_hessian_fingerprint,
            analytic_row_hessian_fingerprint: *analytic_row_hessian_fingerprint,
            block_offsets: Arc::clone(block_offsets),
            penalty_op: penalty_op.clone(),
            device_sae_pcg: device_sae_pcg.clone(),
            cross_row_penalties: cross_row_penalties.clone(),
            row_gauge_deflation: row_gauge_deflation.clone(),
            ibp_cross_row: ibp_cross_row.clone(),
        }
    }
}
/// A latent (Psi-tier) penalty whose Hessian is not row-block-diagonal.
///
/// Its gradient is folded into the rows at assembly time. Its Hessian is applied
/// matrix-free later, so the penalty, its rho slice and the target at which it was
/// assembled are kept here.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
/// The analytic penalty itself.
pub penalty: AnalyticPenaltyKind,
/// This penalty's slice of the global rho vector.
pub rho_local: Array1<f64>,
/// Flattened latent target the penalty was assembled at.
pub target_t: Array1<f64>,
}
/// Sparse low-rank cross-row IBP coupling source.
///
/// `entries` are `(flat_latent_index, atom, value)` triples of a sparse matrix `U` with
/// `r` columns (see `dense_u`). `d` holds one weight per atom. NOTE(review): presumably
/// the coupling term is `U · diag(d) · Uᵀ` — confirm against the assembling code.
#[derive(Clone, Debug, Default)]
pub struct IbpCrossRowSource {
/// Rank (number of atoms / columns of `U`).
pub r: usize,
/// Per-atom weight, indexed by atom.
pub d: Array1<f64>,
/// Non-zeros of `U` as `(flat_latent_index, atom, value)`; duplicates are summed.
pub entries: Vec<(usize, usize, f64)>,
}
impl IbpCrossRowSource {
    /// Materialises `U` as a dense `delta_t_len × r` matrix, summing duplicate entries.
    ///
    /// Panics if an entry's flat index is `>= delta_t_len` or its atom is `>= r`.
    pub(crate) fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((delta_t_len, self.r));
        for &(flat, atom, weight) in self.entries.iter() {
            dense[[flat, atom]] += weight;
        }
        dense
    }

    /// Per-coordinate self term: `out[flat] = Σ d[atom] · value²` over the entries at `flat`.
    ///
    /// Panics if an entry's flat index is `>= delta_t_len`.
    pub(crate) fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
        let mut diagonal = Array1::<f64>::zeros(delta_t_len);
        self.entries.iter().for_each(|&(flat, atom, weight)| {
            diagonal[flat] += self.d[atom] * weight * weight;
        });
        diagonal
    }
}
impl ArrowSchurSystem {
/// System with `n` rows of latent dimension `d` and a zeroed dense `k × k` beta block.
pub fn new(n: usize, d: usize, k: usize) -> Self {
    let dense_hbb = Array2::<f64>::zeros((k, k));
    Self::new_with_hbb(n, d, k, dense_hbb)
}
pub fn new_with_empty_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
htbeta_cols: usize,
) -> Self {
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// System with a dense `k × k` beta block whose storage is taken from `hbb`; the contents
/// are zeroed. Each row's coupling block has `k` columns.
pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
    let coupling_cols = k;
    Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, coupling_cols)
}
pub fn new_with_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
pub fn new_matrix_free_shared<F>(
n: usize,
d: usize,
k: usize,
matvec: F,
diag: Array1<f64>,
) -> Self
where
F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
assert_eq!(diag.len(), k);
let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
k,
Arc::clone(&matvec_arc),
diag.clone(),
)));
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: Some(matvec_arc),
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: Some(diag),
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
htbeta_cols: usize,
) -> Self {
let n = per_row_dims.len();
let d = per_row_dims.iter().copied().max().unwrap_or(0);
let mut offsets = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
offsets.push(cursor);
for &dim in &per_row_dims {
cursor += dim;
offsets.push(cursor);
}
let rows = per_row_dims
.iter()
.map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
.collect();
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims: Arc::from(per_row_dims.into_boxed_slice()),
row_offsets: Arc::from(offsets.into_boxed_slice()),
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let n = per_row_dims.len();
let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
let mut off_vec = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
for &di in &per_row_dims {
off_vec.push(cursor);
cursor += di;
}
off_vec.push(cursor);
let row_offsets: Arc<[usize]> = off_vec.into();
let rows = per_row_dims
.iter()
.map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
.collect();
Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d: max_d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// Installs (or replaces) the per-row gauge deflation.
pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
    self.row_gauge_deflation = Option::Some(deflation);
}
/// Attaches a cross-row IBP source. A source of rank 0 or with no entries carries no
/// coupling and clears the slot instead.
pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
    let carries_coupling = source.r != 0 && !source.entries.is_empty();
    self.ibp_cross_row = carries_coupling.then_some(source);
}
/// Number of rows in the system.
pub fn n(&self) -> usize {
    self.rows.len()
}
/// Fingerprint of the row Hessian blocks currently stored in this system.
pub fn compute_row_hessian_fingerprint(&self) -> u64 {
    row_hessian_fingerprint_for_system(self)
}
/// Row fingerprint combined with the analytic-penalty registry fingerprint.
pub fn current_row_hessian_fingerprint(&self) -> u64 {
    let rows_fingerprint = self.compute_row_hessian_fingerprint();
    let registry_fingerprint = self.analytic_row_hessian_fingerprint;
    combine_row_and_registry_fingerprints(rows_fingerprint, registry_fingerprint)
}
/// Recomputes and caches `row_hessian_fingerprint`.
pub fn refresh_row_hessian_fingerprint(&mut self) {
    let fresh = self.current_row_hessian_fingerprint();
    self.row_hessian_fingerprint = fresh;
}
/// Installs a matrix-free shared beta product `matvec` with diagonal `diag`.
///
/// Sets `hbb_matvec` and `hbb_diag`, and replaces `penalty_op` with a
/// `MatvecDiagPenaltyOp` over the same closure.
///
/// # Panics
/// If `diag.len() != self.k`.
pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
where
    F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    let k = self.k;
    assert_eq!(diag.len(), k);
    let shared: SharedBetaMatvec = Arc::new(matvec);
    let op = MatvecDiagPenaltyOp::new(k, Arc::clone(&shared), diag.clone());
    self.penalty_op = Some(Arc::new(op));
    self.hbb_matvec = Some(shared);
    self.hbb_diag = Some(diag);
}
/// Raises the `htbeta_dense_supplement` flag (read by code outside this impl).
pub fn activate_dense_htbeta_supplement(&mut self) {
    self.htbeta_dense_supplement = true;
}
/// Installs per-row matrix-free coupling operators.
///
/// Each closure receives the row index, an input vector and an output buffer. The
/// streaming path calls `forward` with a beta-space input and a row-latent output, and
/// `transpose` the other way round.
pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
where
    F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
    T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    self.htbeta_transpose_matvec = Some(Arc::new(transpose));
    self.htbeta_matvec = Some(Arc::new(forward));
}
/// Replaces the beta-block partition.
pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
    self.block_offsets = offsets;
}
/// Installs a beta penalty operator, which the `penalty_*` helpers then prefer over
/// `hbb` / `hbb_diag`.
pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
    self.penalty_op = Option::Some(op);
}
/// Attaches device PCG data.
///
/// # Panics
/// If `data.beta_dim != self.k`. When `data.frame` is `None`, also panics if the
/// per-row `a_phi` / `local_jac` collections do not have one entry per row.
pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
    assert_eq!(data.beta_dim, self.k);
    let row_count = self.rows.len();
    if data.frame.is_none() {
        assert_eq!(data.a_phi.len(), row_count);
        assert_eq!(data.local_jac.len(), row_count);
    }
    self.device_sae_pcg = Some(Arc::new(data));
}
/// The beta penalty operator in use: the installed `penalty_op` if any, otherwise a
/// `DensePenaltyOp` over a copy of `hbb`.
pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
    if let Some(installed) = &self.penalty_op {
        return Arc::clone(installed);
    }
    Arc::new(DensePenaltyOp(self.hbb.clone()))
}
/// Adds the shared beta penalty product to `y`, i.e. `y += H_ββ · x`.
///
/// Sources are consulted in this order:
/// 1. the installed `penalty_op`;
/// 2. the dense `hbb` block, when it is non-empty;
/// 3. the stored diagonal `hbb_diag`, when `hbb` is empty.
///
/// The diagonal fallback keeps this product consistent with `penalty_diagonal_add`,
/// `penalty_block_add` and `penalty_subblock_add`, which all fall back to `hbb_diag`.
/// Without it, a diagonal-only system would report a non-zero diagonal or block while
/// its matvec contributed nothing. With no source present, `y` is left unchanged.
#[inline]
pub(crate) fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.matvec(x, y);
        return;
    }
    let k = self.hbb.nrows();
    if k > 0 {
        for a in 0..k {
            let mut acc = 0.0_f64;
            for b in 0..k {
                acc += self.hbb[[a, b]] * x[b];
            }
            y[a] += acc;
        }
    } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
        // Diagonal-only penalty: y_j += H_jj · x_j.
        for ((yj, &hj), &xj) in y.iter_mut().zip(hbb_diag.iter()).zip(x.iter()) {
            *yj += hj * xj;
        }
    }
}
/// Computes the ridged shared-block product `(H_ββ + ridge · I) · x` into `y`.
///
/// * Dense parallel path: taken when `parallel` is requested, no `penalty_op` is
///   installed, `hbb` is dense `k × k` and `k >= SCHUR_PROLOGUE_PARALLEL_K_MIN`. Every
///   `y[a]` is *overwritten* with `(H_ββ x)[a] + ridge · x[a]`.
/// * Otherwise the penalty product is *added* into `y` through `penalty_matvec_add`,
///   followed by the ridge term. NOTE(review): the two paths agree only when callers
///   pass a zeroed `y` — confirm.
///
/// The ridge always covers all `self.k` beta coordinates. Bounding it by `hbb.nrows()`
/// would silently drop `ridge` for operator-backed, matrix-free and empty-hbb systems,
/// where `hbb` is `0 × 0`.
pub(crate) fn penalty_ridge_prologue_into(
    &self,
    x: &[f64],
    ridge: f64,
    y: &mut [f64],
    parallel: bool,
) {
    let k = self.k;
    let dense_parallel = parallel
        && self.penalty_op.is_none()
        && self.hbb.dim() == (k, k)
        && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
    if dense_parallel {
        use rayon::prelude::*;
        let hbb = &self.hbb;
        y.par_iter_mut().enumerate().for_each(|(a, ya)| {
            let mut acc = 0.0_f64;
            for b in 0..k {
                acc += hbb[[a, b]] * x[b];
            }
            *ya = acc + ridge * x[a];
        });
    } else {
        self.penalty_matvec_add(x, y);
        for a in 0..k {
            y[a] += ridge * x[a];
        }
    }
}
/// Adds the diagonal of the shared beta penalty into `diag`.
///
/// Uses `penalty_op` when installed. Otherwise it uses the stored `hbb_diag`, and only
/// failing that the diagonal of the dense `hbb`. In the fallback cases only the common
/// prefix of the two lengths is touched.
#[inline]
pub(crate) fn penalty_diagonal_add(&self, diag: &mut [f64]) {
    match (self.penalty_op.as_ref(), self.hbb_diag.as_ref()) {
        (Some(op), _) => {
            op.diagonal(diag);
        }
        (None, Some(stored)) => {
            for (slot, &value) in diag.iter_mut().zip(stored.iter()) {
                *slot += value;
            }
        }
        (None, None) => {
            let limit = self.hbb.nrows();
            for (j, slot) in diag.iter_mut().enumerate().take(limit) {
                *slot += self.hbb[[j, j]];
            }
        }
    }
}
/// Adds the diagonal block `id` (range `offsets[id.0]`) of the shared beta penalty into
/// `out`.
///
/// Uses `penalty_op` when installed. Otherwise it copies the block from a dense `k × k`
/// `hbb`, or adds only the stored diagonal `hbb_diag`. With none of these, nothing is
/// added.
#[inline]
pub(crate) fn penalty_block_add(
    &self,
    id: BetaBlockId,
    offsets: &[Range<usize>],
    out: &mut Array2<f64>,
) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.block(id, offsets, out);
        return;
    }
    let Range { start, end } = offsets[id.0].clone();
    let width = end - start;
    let dense = self.hbb.dim() == (self.k, self.k);
    if dense {
        for bi in 0..width {
            for bj in 0..width {
                out[[bi, bj]] += self.hbb[[start + bi, start + bj]];
            }
        }
    } else if let Some(stored) = self.hbb_diag.as_ref() {
        for bi in 0..width {
            out[[bi, bi]] += stored[start + bi];
        }
    }
}
/// Adds the principal sub-block of the shared beta penalty selected by `cols` into
/// `out` (`out[[bi, bj]] += H[cols[bi], cols[bj]]`).
///
/// With a `penalty_op`, each selected column is probed by multiplying the operator with
/// the unit vector `e_col` (one matvec per column). Otherwise the entries come from a
/// dense `k × k` `hbb`, or only the diagonal from `hbb_diag`.
#[inline]
pub(crate) fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
    match self.penalty_op.as_ref() {
        Some(op) => {
            let mut unit = Array1::<f64>::zeros(self.k);
            let mut image = Array1::<f64>::zeros(self.k);
            for (bj, &col) in cols.iter().enumerate() {
                unit.fill(0.0);
                unit[col] = 1.0;
                image.fill(0.0);
                op.matvec(
                    unit.as_slice().expect("probe contiguous"),
                    image.as_slice_mut().expect("result contiguous"),
                );
                for (bi, &row) in cols.iter().enumerate() {
                    out[[bi, bj]] += image[row];
                }
            }
        }
        None if self.hbb.dim() == (self.k, self.k) => {
            for (bi, &row) in cols.iter().enumerate() {
                for (bj, &col) in cols.iter().enumerate() {
                    out[[bi, bj]] += self.hbb[[row, col]];
                }
            }
        }
        None => {
            if let Some(stored) = self.hbb_diag.as_ref() {
                for (bi, &row) in cols.iter().enumerate() {
                    out[[bi, bi]] += stored[row];
                }
            }
        }
    }
}
/// Folds every penalty in `registry` into this system.
///
/// Each penalty is paired with its `rho_global` slice from `registry.rho_layout()`:
/// - `PenaltyTier::Psi`, row-block-diagonal: gradient and row Hessian blocks are added
///   into `rows` (`add_ext_coord_penalty`), and its row-Hessian fingerprint (if any)
///   is recorded.
/// - `PenaltyTier::Psi`, not row-block-diagonal: only the gradient is added to `rows`;
///   the penalty is stored in `cross_row_penalties` so that
///   `apply_cross_row_penalty_hessian` can apply its Hessian matrix-free.
/// - `PenaltyTier::Beta`: added to the shared beta gradient / Hessian (`add_beta_penalty`).
/// - `PenaltyTier::Rho`: nothing to add here.
///
/// `cross_row_penalties` is cleared first, but Psi/Beta contributions are *added* to
/// whatever `rows`, `gb` and `hbb` already hold. `analytic_row_hessian_fingerprint` is
/// then recomputed from the collected fingerprints (`0` when there are none).
///
/// # Errors
/// Currently always returns `Ok(())`.
///
/// # Panics
/// If a layout range lies outside `rho_global` (the slice operation panics).
pub fn add_analytic_penalty_contributions(
&mut self,
registry: &AnalyticPenaltyRegistry,
target_t: ArrayView1<'_, f64>,
target_beta: ArrayView1<'_, f64>,
rho_global: ArrayView1<'_, f64>,
) -> Result<(), ArrowSchurError> {
let layout = registry.rho_layout();
let mut penalty_fingerprints = Vec::new();
// Cross-row penalties are rebuilt from scratch on every call.
self.cross_row_penalties.clear();
for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
match tier {
PenaltyTier::Psi => {
if analytic_penalty_is_row_block_diagonal(penalty) {
self.add_ext_coord_penalty(penalty, target_t, rho_local);
if let Some(fingerprint) =
analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
{
penalty_fingerprints.push(fingerprint);
}
} else {
// Hessian couples rows: fold only the gradient now and keep the
// penalty for the matrix-free Hessian product.
self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
self.cross_row_penalties.push(CrossRowLatentPenalty {
penalty: penalty.clone(),
rho_local: rho_local.to_owned(),
target_t: target_t.to_owned(),
});
}
}
PenaltyTier::Beta => {
self.add_beta_penalty(penalty, target_beta, rho_local);
}
PenaltyTier::Rho => {
}
}
}
// Cross-row penalties are fingerprinted after all block-diagonal ones, so the
// resulting order is stable for a given registry.
for cross in &self.cross_row_penalties {
penalty_fingerprints.push(cross_row_penalty_fingerprint(
&cross.penalty,
target_t,
cross.rho_local.view(),
));
}
self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
0
} else {
// Domain-tagged, length-prefixed hash of the ordered per-penalty fingerprints.
let mut hasher = Fingerprinter::new();
hasher.write_str("arrow-schur-row-hessian-registry-v1");
hasher.write_usize(penalty_fingerprints.len());
for fingerprint in penalty_fingerprints {
hasher.write_u64(fingerprint);
}
hasher.finish_u64()
};
Ok(())
}
/// Re-expresses every row's gradient, Hessian and coupling block for the latent
/// manifold of `latent`.
///
/// Always records the manifold-mode fingerprint. Euclidean manifolds need no further
/// work. Otherwise, for each row `i` at point `t_i`:
/// - `gt` is projected to the tangent space;
/// - `htt` becomes the Riemannian Hessian;
/// - the columns of `htbeta` are projected.
///
/// All three are computed from the row's *original* (Euclidean) quantities.
///
/// # Panics
/// On a non-Euclidean manifold, if `latent` does not have one row per system row or its
/// latent dimension differs from `self.d`.
pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
    let manifold = latent.manifold();
    self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
    if manifold.is_euclidean() {
        return;
    }
    assert_eq!(latent.n_obs(), self.rows.len());
    assert_eq!(latent.latent_dim(), self.d);
    for (i, row) in self.rows.iter_mut().enumerate() {
        let t_i = ArrayView1::from(latent.row(i));
        // Compute everything from the untouched Euclidean fields before overwriting any.
        let tangent_gt = manifold.project_gradient_to_tangent(t_i, row.gt.view());
        let riemannian_htt =
            manifold.riemannian_hessian_matrix(t_i, row.gt.view(), row.htt.view());
        let projected_htbeta = manifold.project_matrix_columns_to_gradient_tangent(
            t_i,
            row.gt.view(),
            row.htbeta.view(),
        );
        row.gt = tangent_gt;
        row.htt = riemannian_htt;
        row.htbeta = projected_htbeta;
    }
}
/// Adds a row-block-diagonal latent penalty (gradient and per-row Hessian blocks) into
/// `rows`.
///
/// The flattened latent vector is addressed as `flat = i * d + c` (row `i`, axis `c`),
/// with `d = self.d` for every row.
/// NOTE(review): for ragged systems built from per-row dims, `n * d` and `flat / d` do
/// not match the `row_offsets` layout — confirm this is never reached for them.
///
/// `apply_analytic_penalty` receives `d` as the number of Hessian probe columns and four
/// closures: gradient scatter, Hessian-diagonal scatter, probe construction for column
/// `a`, and scatter of the resulting Hessian-vector product. The probe sets axis `a` in
/// *every* row at once, which yields column `a` of every row block simultaneously. That
/// is valid only because the penalty's Hessian is row-block-diagonal, which is how the
/// caller selects this path.
pub(crate) fn add_ext_coord_penalty(
&mut self,
penalty: &AnalyticPenaltyKind,
target_t: ArrayView1<'_, f64>,
rho_local: ArrayView1<'_, f64>,
) {
let d = self.d;
let n = self.rows.len();
apply_analytic_penalty(
penalty,
target_t,
rho_local,
n * d,
d,
self,
// Gradient entry `flat` -> row `flat / d`, axis `flat % d`.
|sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
// Hessian-diagonal entry `flat` -> diagonal of that row's `htt`.
|sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
// Probe for column `a`: unit entry at axis `a` of every row.
|a, probe| {
for i in 0..n {
probe[i * d + a] = 1.0;
}
},
// Scatter the product for probe `a` into column `a` of each row's `htt`.
|sys, a, hv| {
for i in 0..n {
for b in 0..d {
sys.rows[i].htt[[b, a]] += hv[i * d + b];
}
}
},
);
}
/// Adds only the gradient of a latent penalty into the rows' `gt`. This is used for
/// penalties whose Hessian couples rows and is applied matrix-free elsewhere.
///
/// Assumes the uniform layout `flat = i * d + c` with `d = self.d`.
///
/// # Panics
/// If `target_t.len() != n_rows * d`.
pub(crate) fn add_ext_coord_penalty_gradient_only(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let d = self.d;
    let n = self.rows.len();
    assert_eq!(target_t.len(), n * d);
    let grad = penalty.grad_target(target_t, rho_local);
    for (i, row) in self.rows.iter_mut().enumerate() {
        let base = i * d;
        for c in 0..d {
            row.gt[c] += grad[base + c];
        }
    }
}
/// Adds, for every stored cross-row penalty, its PSD-majorizer Hessian-vector product
/// with `v` into `out`. Each product is evaluated at the penalty's stored target.
///
/// # Panics
/// If a stored target's length differs from `v.len()`, or a product's length differs
/// from `out.len()`.
pub(crate) fn apply_cross_row_penalty_hessian(
    &self,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    for cross in self.cross_row_penalties.iter() {
        assert_eq!(cross.target_t.len(), v.len());
        let product = cross.penalty.psd_majorizer_hvp(
            cross.target_t.view(),
            cross.rho_local.view(),
            v,
        );
        assert_eq!(product.len(), out.len());
        for (i, slot) in out.iter_mut().enumerate() {
            *slot += product[i];
        }
    }
}
/// Adds a beta-tier analytic penalty into the shared beta gradient and Hessian.
///
/// - Gradient entries go to `gb[j]`.
/// - Hessian-diagonal entries go to `hbb[[j, j]]` when `hbb` is dense `k × k`, and to
///   `hbb_diag[j]` when a stored diagonal exists.
/// - Full Hessian columns (probe `e_j`) are requested only when `hbb` is dense
///   (`hvp_columns = k`, otherwise `0`). Column `j` is added into `hbb[:, j]`, and its
///   diagonal entry also into `hbb_diag[j]` when present.
///
/// NOTE(review): when `penalty_op` is installed, `penalty_matvec_add`,
/// `penalty_diagonal_add` and `effective_penalty_op` read `penalty_op` in preference to
/// `hbb` / `hbb_diag`, so the Hessian updates made here are not visible through them —
/// confirm this is intended.
pub(crate) fn add_beta_penalty(
&mut self,
penalty: &AnalyticPenaltyKind,
target_beta: ArrayView1<'_, f64>,
rho_local: ArrayView1<'_, f64>,
) {
let k = self.k;
// Only a dense beta block can absorb full Hessian columns; otherwise diagonal only.
let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
apply_analytic_penalty(
penalty,
target_beta,
rho_local,
k,
hvp_columns,
self,
|sys, j, value| sys.gb[j] += value,
|sys, j, value| {
if sys.hbb.dim() == (k, k) {
sys.hbb[[j, j]] += value;
}
if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
hbb_diag[j] += value;
}
},
|j, probe| probe[j] = 1.0,
// Presumably invoked only for j < hvp_columns, i.e. only when `hbb` is dense.
|sys, j, hv| {
for i in 0..k {
sys.hbb[[i, j]] += hv[i];
}
if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
hbb_diag[j] += hv[j];
}
},
);
}
/// Solves the arrow Newton system with `ArrowSolveOptions::automatic(self.k)`.
///
/// Returns `(delta_t, delta_beta, diagnostics)`.
pub fn solve(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    self.solve_with_options(ridge_t, ridge_beta, &ArrowSolveOptions::automatic(self.k))
}
/// Like `solve`, but routed through `solve_with_lm_escalation_inner`. Presumably this
/// retries with escalated (Levenberg–Marquardt) ridges on failure — confirm there.
pub fn solve_with_lm_escalation(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let automatic = ArrowSolveOptions::automatic(self.k);
    solve_with_lm_escalation_inner(self, ridge_t, ridge_beta, &automatic)
}
/// Solves the arrow Newton system with caller-supplied options.
pub fn solve_with_options(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
}
}
/// Row-streaming variant of the arrow Schur reduction.
///
/// Instead of keeping every row block resident, rows are produced on demand by
/// `row_builder` and folded into a dense `k × k` Schur accumulator, `chunk_size` rows at
/// a time.
pub struct StreamingArrowSchur {
/// Number of rows `row_builder` can produce.
pub n_rows: usize,
/// Latent dimension (copied from `ArrowSchurSystem::d` by `from_system`).
pub d: usize,
/// Latent dimension of each row.
pub row_dims: Arc<[usize]>,
/// Prefix offsets of each row in flattened latent vectors (length `n_rows + 1`).
pub row_offsets: Arc<[usize]>,
/// Shared beta dimension.
pub k: usize,
/// Rows per streaming chunk; at least 1 when built through `new`.
pub chunk_size: usize,
/// Reduced Schur accumulator (`k × k`), seeded from `hbb` plus the beta ridge.
pub s_acc: Array2<f64>,
/// Reduced right-hand-side accumulator (length `k`).
pub(crate) rhs_acc: Array1<f64>,
/// Dense shared beta Hessian (`k × k`).
pub(crate) hbb: Array2<f64>,
/// Shared beta gradient (length `k`).
pub(crate) gb: Array1<f64>,
/// Produces row `i`'s block on demand; its errors propagate out of the streaming passes.
pub(crate) row_builder: StreamingArrowRowBuilder,
/// Optional matrix-free per-row coupling product (beta input, row-latent output).
pub(crate) htbeta_matvec: Option<RowHtbetaMatvec>,
/// Optional matrix-free per-row transpose coupling product (row-latent input, beta output).
pub(crate) htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
/// Forwarded to the per-row factorization.
pub(crate) tolerate_ill_conditioning: bool,
/// True when an IBP cross-row source is attached; the plain streaming solve and log-det
/// paths refuse to run in that case.
pub(crate) ibp_cross_row_active: bool,
/// Optional per-row gauge deflation applied when factoring rows.
pub(crate) row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
/// Optional IBP cross-row source consumed by `reduced_schur_log_det_tt_woodbury`.
pub(crate) ibp_cross_row: Option<IbpCrossRowSource>,
}
/// Ingredients of the cross-row IBP Woodbury correction accumulated by
/// `StreamingArrowSchur::reduced_schur_log_det_tt_woodbury`.
#[derive(Debug, Clone)]
pub struct StreamingWoodburyChunk {
/// `Σ_i U_iᵀ · solve(factor_i, U_i)` over rows touched by the IBP source (`r × r`).
pub m0: Array2<f64>,
/// `Σ_i H_tβ,iᵀ · solve(factor_i, U_i)` (`k × r`).
pub w: Array2<f64>,
/// Copy of the source's per-atom weights.
pub d: Array1<f64>,
}
impl std::fmt::Debug for StreamingArrowSchur {
    /// Shows only the dimensions; accumulators, the row builder and operators are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("StreamingArrowSchur");
        summary.field("n_rows", &self.n_rows);
        summary.field("d", &self.d);
        summary.field("k", &self.k);
        summary.field("chunk_size", &self.chunk_size);
        summary.finish_non_exhaustive()
    }
}
impl StreamingArrowSchur {
/// Streaming reducer over `n_rows` rows produced by `row_builder`.
///
/// Accumulators start at zero and `chunk_size` is clamped to at least 1. No coupling
/// operators, gauge deflation or IBP source are attached.
///
/// # Panics
/// If `hbb` is not `k × k` or `gb.len() != k`.
#[must_use]
pub fn new(
    n_rows: usize,
    d: usize,
    row_dims: Arc<[usize]>,
    row_offsets: Arc<[usize]>,
    k: usize,
    hbb: Array2<f64>,
    gb: Array1<f64>,
    row_builder: StreamingArrowRowBuilder,
    chunk_size: usize,
) -> Self {
    assert_eq!(hbb.dim(), (k, k));
    assert_eq!(gb.len(), k);
    let chunk_size = chunk_size.max(1);
    let s_acc = Array2::<f64>::zeros((k, k));
    let rhs_acc = Array1::<f64>::zeros(k);
    Self {
        n_rows,
        d,
        row_dims,
        row_offsets,
        k,
        chunk_size,
        s_acc,
        rhs_acc,
        hbb,
        gb,
        row_builder,
        htbeta_matvec: None,
        htbeta_transpose_matvec: None,
        tolerate_ill_conditioning: false,
        ibp_cross_row_active: false,
        row_gauge_deflation: None,
        ibp_cross_row: None,
    }
}
#[must_use]
pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
let htbeta_matvec = sys.htbeta_matvec.clone();
let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
sys.rows
.iter()
.map(|row| ArrowRowBlock {
htt: row.htt.clone(),
htbeta: Array2::<f64>::zeros((0, 0)),
gt: row.gt.clone(),
})
.collect()
} else {
sys.rows.clone()
};
let rows = Arc::new(rows);
let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
rows.get(row)
.cloned()
.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
reason: format!("streaming row {row} out of bounds"),
})
});
let hbb_dense = sys.effective_penalty_op().to_dense();
let mut streaming = Self::new(
sys.rows.len(),
sys.d,
Arc::clone(&sys.row_dims),
Arc::clone(&sys.row_offsets),
sys.k,
hbb_dense,
sys.gb.clone(),
row_builder,
chunk_size,
);
streaming.htbeta_matvec = htbeta_matvec;
streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
streaming.ibp_cross_row = sys.ibp_cross_row.clone();
streaming.row_gauge_deflation = sys.row_gauge_deflation.clone();
streaming
}
/// Factors row `row_idx`'s ridged latent block.
///
/// Goes through `factor_one_row_result` with that row's deflation when a gauge deflation
/// is attached, otherwise through `factor_one_row`. Both honour
/// `tolerate_ill_conditioning`.
fn factor_row(
    &self,
    row: &ArrowRowBlock,
    ridge_t: f64,
    di: usize,
    row_idx: usize,
) -> Result<Array2<f64>, ArrowSchurError> {
    let tolerate = self.tolerate_ill_conditioning;
    let Some(deflation) = self.row_gauge_deflation.as_ref() else {
        return factor_one_row(row, ridge_t, di, row_idx, tolerate);
    };
    let deflated = factor_one_row_result(
        row,
        ridge_t,
        di,
        row_idx,
        tolerate,
        deflation.row(row_idx),
        true,
    )?;
    Ok(deflated.factor)
}
/// Dense `di × k` coupling block of row `row_idx`.
///
/// Prefers the transpose operator (one product per latent axis, filling a row each).
/// Next it tries the forward operator (one product per beta axis, filling a column
/// each). Otherwise it returns a copy of the row's stored `htbeta`.
pub(crate) fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
    let k = self.k;
    if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
        let mut dense = Array2::<f64>::zeros((di, k));
        let mut unit = Array1::<f64>::zeros(di);
        let mut image = Array1::<f64>::zeros(k);
        for c in 0..di {
            unit.fill(0.0);
            unit[c] = 1.0;
            image.fill(0.0);
            op_t(row_idx, unit.view(), &mut image);
            for a in 0..k {
                dense[[c, a]] = image[a];
            }
        }
        return dense;
    }
    let Some(op) = self.htbeta_matvec.as_ref() else {
        return row.htbeta.clone();
    };
    let mut dense = Array2::<f64>::zeros((di, k));
    let mut unit = Array1::<f64>::zeros(k);
    let mut image = Array1::<f64>::zeros(di);
    for a in 0..k {
        unit.fill(0.0);
        unit[a] = 1.0;
        image.fill(0.0);
        op(row_idx, unit.view(), &mut image);
        for c in 0..di {
            dense[[c, a]] = image[c];
        }
    }
    dense
}
/// Moves out `(s_acc, rhs_acc)` and leaves fresh zeroed accumulators of the same shape.
#[must_use]
pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
    let k = self.k;
    let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((k, k)));
    let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(k));
    (schur, rhs)
}
/// Seeds `s_acc = hbb + ridge_beta · I` and zeroes the first `k` entries of `rhs_acc`.
///
/// # Errors
/// `SchurFactorFailed` when `hbb` is not a dense `k × k` block.
pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
    let k = self.k;
    if self.hbb.dim() == (k, k) {
        self.s_acc.assign(&self.hbb);
        for j in 0..k {
            self.s_acc[[j, j]] += ridge_beta;
        }
        for j in 0..k {
            self.rhs_acc[j] = 0.0;
        }
        Ok(())
    } else {
        Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
        })
    }
}
/// Folds rows `start..end` into the Schur accumulator `s_acc` and the reduced
/// right-hand side `rhs_acc`.
///
/// Each row `i` is built by `row_builder`, validated, and factored by `factor_row`. Then:
/// - `rhs_acc += H_tβ,iᵀ · v`, where `v` is the factor solve of `g_t,i`; zero entries
///   of `v` are skipped.
/// - `s_acc` is downdated through `block_gemm_subtract`. Direct / InexactPCG use the
///   coupling block and its factor solve; SqrtBA uses the whitened coupling block twice.
///
/// Rows are processed in parallel when the range has at least
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and the caller is not already on a rayon worker.
/// The parallel path uses rayon chunks of 64 rows with per-chunk partial accumulators,
/// summed in chunk order, so rounding may differ slightly from the serial path. On error,
/// the parallel path leaves the accumulators untouched, while the serial path may leave
/// them partially updated.
///
/// # Errors
/// `SchurFactorFailed` if the range is not within `0..n_rows`; otherwise any error from
/// `row_builder`, `validate_row` or `factor_row`.
pub fn accumulate_chunk(
&mut self,
start: usize,
end: usize,
ridge_t: f64,
mode: ArrowSolverMode,
) -> Result<(), ArrowSchurError> {
if start > end || end > self.n_rows {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
self.n_rows
),
});
}
let backend = CpuBatchedBlockSolver;
let k = self.k;
// Avoid nested rayon parallelism: only fan out from a non-worker thread.
let parallel = (end - start) >= SCHUR_MATVEC_PARALLEL_ROW_MIN
&& rayon::current_thread_index().is_none();
if parallel {
use rayon::prelude::*;
const CHUNK: usize = 64;
// Shared borrow for the worker closures; `self` is mutated only after collection.
let this: &Self = self;
let row_into = |row_idx: usize,
rhs_part: &mut Array1<f64>,
s_part: &mut Array2<f64>|
-> Result<(), ArrowSchurError> {
let row = (this.row_builder)(row_idx)?;
let di = row.htt.nrows();
this.validate_row(row_idx, &row)?;
let htbeta = this.row_htbeta(row_idx, &row, di);
let factor = this.factor_row(&row, ridge_t, di, row_idx)?;
let v = backend.solve_block_vector(factor.view(), row.gt.view());
for c in 0..di {
let vc = v[c];
if vc == 0.0 {
continue;
}
for a in 0..k {
rhs_part[a] += htbeta[[c, a]] * vc;
}
}
match mode {
ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(s_part, &htbeta, &solved);
}
ArrowSolverMode::SqrtBA => {
let whitened =
backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(s_part, &whitened, &whitened);
}
}
Ok(())
};
// One zero-initialised partial (rhs, Schur) pair per 64-row chunk.
let partials: Vec<(Array1<f64>, Array2<f64>)> = (start..end)
.into_par_iter()
.chunks(CHUNK)
.map(|idxs| {
let mut rhs_part = Array1::<f64>::zeros(k);
let mut s_part = Array2::<f64>::zeros((k, k));
for i in idxs {
row_into(i, &mut rhs_part, &mut s_part)?;
}
Ok::<_, ArrowSchurError>((rhs_part, s_part))
})
.collect::<Result<Vec<_>, _>>()?;
// Combine partials sequentially in chunk order.
for (rhs_part, s_part) in &partials {
for a in 0..k {
self.rhs_acc[a] += rhs_part[a];
}
self.s_acc += s_part;
}
} else {
// Serial path: same per-row work, accumulated directly into `self`.
for row_idx in start..end {
let row = (self.row_builder)(row_idx)?;
let di = row.htt.nrows();
self.validate_row(row_idx, &row)?;
let htbeta = self.row_htbeta(row_idx, &row, di);
let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
let v = backend.solve_block_vector(factor.view(), row.gt.view());
for c in 0..di {
let vc = v[c];
if vc == 0.0 {
continue;
}
for a in 0..k {
self.rhs_acc[a] += htbeta[[c, a]] * vc;
}
}
match mode {
ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
}
ArrowSolverMode::SqrtBA => {
let whitened =
backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
}
}
}
}
Ok(())
}
/// Streams every row once and returns `(log_det_tt, schur)`.
///
/// - `log_det_tt = Σ_rows Σ_axis 2 · ln(factor[[axis, axis]])`, i.e. the log-determinant
///   of the latent blocks, assuming `factor_row` returns a triangular (Cholesky-type)
///   factor — NOTE(review): confirm.
/// - `schur` is `hbb + ridge_beta · I` downdated by every row's coupling, as in
///   `accumulate_chunk`, then symmetrised.
///
/// Rows are processed serially, `chunk_size` at a time. `rhs_acc` is only zeroed here,
/// and `s_acc` is left as a fresh zero matrix.
///
/// # Errors
/// - `SchurFactorFailed` when an IBP cross-row source is active (the exact Woodbury
///   correction needs the per-row factors retained).
/// - Any error from `reset_accumulator`, `row_builder`, `validate_row` or `factor_row`.
pub fn reduced_schur_and_log_det_tt(
&mut self,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(f64, Array2<f64>), ArrowSchurError> {
if self.ibp_cross_row_active {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
Woodbury correction (#1038): U's columns span all rows, so the \
rank-R capacitance needs the per-row factors retained — the very \
(N·K) residency the streaming path avoids. Route IBP-active fits \
through the dense resident ArrowFactorCache::arrow_log_det instead."
.to_string(),
});
}
self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
self.reset_accumulator(ridge_beta)?;
let backend = CpuBatchedBlockSolver;
let mut log_det_tt = 0.0_f64;
for start in (0..self.n_rows).step_by(self.chunk_size) {
let end = (start + self.chunk_size).min(self.n_rows);
for row_idx in start..end {
let row = (self.row_builder)(row_idx)?;
let di = row.htt.nrows();
self.validate_row(row_idx, &row)?;
let htbeta = self.row_htbeta(row_idx, &row, di);
let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
// log det of the row block from its factor's diagonal.
for axis in 0..di {
log_det_tt += 2.0 * factor[[axis, axis]].ln();
}
match options.mode {
ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
}
ArrowSolverMode::SqrtBA => {
let whitened =
backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
}
}
}
}
symmetrize_upper_from_lower(&mut self.s_acc);
let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
Ok((log_det_tt, schur))
}
/// Streaming reduced Schur complement and latent-block log-determinant, plus the
/// ingredients of the exact cross-row IBP Woodbury correction.
///
/// Without an IBP source this is `reduced_schur_and_log_det_tt` with a `None` chunk.
/// Otherwise, each row's `htt` diagonal is first downdated by the source's self term
/// (`IbpCrossRowSource::self_term_downdate`). The row is then factored and folded into
/// `log_det_tt` and the Schur accumulator exactly as in the plain path. Rows touched by
/// the source additionally contribute `U_iᵀ · solve(factor_i, U_i)` to `m0` and
/// `H_tβ,iᵀ · solve(factor_i, U_i)` to `w`.
///
/// Each IBP entry is routed to the row owning its flat latent index. `row_offsets` is
/// non-decreasing but repeats whenever a row has zero latent axes, and `binary_search`
/// may return *any* of several equal offsets. That could select an empty row (out-of-
/// bounds panic) or the wrong row. `partition_point` always selects the last row whose
/// start is `<= g`, which is the unique row containing `g`.
///
/// # Errors
/// Any error from `reset_accumulator`, `row_builder`, `validate_row` or `factor_row`.
///
/// # Panics
/// If an IBP entry's flat index is outside `0..row_offsets[n_rows]`.
pub fn reduced_schur_log_det_tt_woodbury(
    &mut self,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(f64, Array2<f64>, Option<StreamingWoodburyChunk>), ArrowSchurError> {
    let Some(source) = self.ibp_cross_row.clone() else {
        let (log_det_tt, schur) =
            self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
        return Ok((log_det_tt, schur, None));
    };
    let r = source.r;
    let total_len = self.row_offsets[self.n_rows];
    // Also bounds-checks every entry's flat index against `total_len`.
    let down = source.self_term_downdate(total_len);
    let mut row_entries: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); self.n_rows];
    for &(g, atom, z) in &source.entries {
        // `row_offsets[0] == 0 <= g`, so the subtraction cannot underflow.
        let i = self.row_offsets.partition_point(|&offset| offset <= g) - 1;
        let slot = g - self.row_offsets[i];
        row_entries[i].push((slot, atom, z));
    }
    self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
    self.reset_accumulator(ridge_beta)?;
    let backend = CpuBatchedBlockSolver;
    let mut log_det_tt = 0.0_f64;
    let mut m0 = Array2::<f64>::zeros((r, r));
    let mut w = Array2::<f64>::zeros((self.k, r));
    for start in (0..self.n_rows).step_by(self.chunk_size) {
        let end = (start + self.chunk_size).min(self.n_rows);
        for row_idx in start..end {
            let mut row = (self.row_builder)(row_idx)?;
            let di = row.htt.nrows();
            self.validate_row(row_idx, &row)?;
            // Remove the IBP self term from this row's diagonal before factoring.
            let base = self.row_offsets[row_idx];
            for j in 0..di {
                row.htt[[j, j]] -= down[base + j];
            }
            let htbeta = self.row_htbeta(row_idx, &row, di);
            let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
            for axis in 0..di {
                log_det_tt += 2.0 * factor[[axis, axis]].ln();
            }
            match options.mode {
                ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                    let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                }
                ArrowSolverMode::SqrtBA => {
                    let whitened =
                        backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                }
            }
            let entries = &row_entries[row_idx];
            if !entries.is_empty() {
                // Local slice of U for this row: `di × r`, duplicates summed.
                let mut u_local = Array2::<f64>::zeros((di, r));
                for &(slot, atom, z) in entries {
                    u_local[[slot, atom]] += z;
                }
                let ainv_u = backend.solve_block_matrix(factor.view(), u_local.view());
                m0 += &u_local.t().dot(&ainv_u);
                w += &htbeta.t().dot(&ainv_u);
            }
        }
    }
    symmetrize_upper_from_lower(&mut self.s_acc);
    let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
    Ok((
        log_det_tt,
        schur,
        Some(StreamingWoodburyChunk {
            m0,
            w,
            d: source.d.clone(),
        }),
    ))
}
/// Log-determinant of a reduced Schur matrix.
///
/// The matrix is factored through `solve_dense_reduced_system` against a zero RHS, and
/// the result is `Σ 2 · ln(factor[[a, a]])` over the returned dense factor.
///
/// # Errors
/// - Any error from `solve_dense_reduced_system`.
/// - `SchurFactorFailed` when the solve reports a wrong-length solution or non-zero
///   iterations (i.e. it did not go through a direct factorization), or when no dense
///   factor is returned.
pub fn reduced_schur_log_det(
    schur: &Array2<f64>,
    options: &ArrowSolveOptions,
) -> Result<f64, ArrowSchurError> {
    let k = schur.nrows();
    let zero_rhs = Array1::<f64>::zeros(k);
    let (delta, maybe_factor, diag) = solve_dense_reduced_system(schur, &zero_rhs, options, None)?;
    let coherent = delta.len() == k && diag.iterations == 0;
    if !coherent {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming log-det reduced solve returned incoherent diagnostics"
                .to_string(),
        });
    }
    let Some(factor) = maybe_factor else {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
        });
    };
    let log_det_schur = (0..factor.nrows())
        .fold(0.0_f64, |acc, axis| acc + 2.0 * factor[[axis, axis]].ln());
    Ok(log_det_schur)
}
/// Exact log-determinant of the full arrow system: the latent blocks' log-determinant
/// plus that of the reduced Schur complement.
///
/// # Errors
/// Propagates errors from `reduced_schur_and_log_det_tt` (including the IBP-active
/// refusal) and from `reduced_schur_log_det`.
pub fn exact_arrow_log_det(
    &mut self,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<f64, ArrowSchurError> {
    let (latent_part, schur) = self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
    let schur_part = Self::reduced_schur_log_det(&schur, options)?;
    Ok(latent_part + schur_part)
}
/// Streaming arrow Newton solve.
///
/// Accumulates the reduced system chunk by chunk (`accumulate_chunk`), subtracts `gb`
/// from the reduced RHS, symmetrises, solves for `delta_beta`, and back-substitutes for
/// `delta_t`. Returns `(delta_t, delta_beta, schur_factor)`.
///
/// # Errors
/// - `SchurFactorFailed` when an IBP cross-row source is active.
/// - Any error from accumulation, the reduced solve or back-substitution.
pub fn solve(
    &mut self,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
    if self.ibp_cross_row_active {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming arrow solve cannot carry the exact cross-row IBP \
                     Woodbury correction (#1038); route IBP-active fits through the \
                     dense resident solve_arrow_newton_step_with_options instead."
                .to_string(),
        });
    }
    self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
    self.reset_accumulator(ridge_beta)?;
    let (n_rows, chunk) = (self.n_rows, self.chunk_size);
    for start in (0..n_rows).step_by(chunk) {
        let end = n_rows.min(start + chunk);
        self.accumulate_chunk(start, end, ridge_t, options.mode)?;
    }
    let k = self.k;
    for j in 0..k {
        self.rhs_acc[j] -= self.gb[j];
    }
    symmetrize_upper_from_lower(&mut self.s_acc);
    let (delta_beta, schur_factor, _diag) =
        solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, None)?;
    let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
    Ok((delta_t, delta_beta, schur_factor))
}
/// Per-row back-substitution for the latent step of the arrow system.
///
/// For every row `i`, the row block is rebuilt via `row_builder` and validated.
/// Its `H_tt` block is then factored with `ridge_t` (via `factor_row`). The
/// result `-solve(factor_i, g_t^{(i)} + H_tβ^{(i)} Δβ)` is written into the
/// row's segment, which starts at `row_offsets[i]`.
///
/// Rows are processed in parallel (rayon, chunks of 64) when there are at
/// least `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and the caller is not already
/// on a rayon worker thread; otherwise they are processed sequentially.
///
/// # Errors
/// The first row-build, validation or row-factorisation error encountered.
pub(crate) fn back_substitute(
    &self,
    ridge_t: f64,
    delta_beta: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, ArrowSchurError> {
    let total_len = self.row_offsets[self.n_rows];
    let mut delta_t = Array1::<f64>::zeros(total_len);
    let parallel =
        self.n_rows >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    if parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 64;
        let segments: Vec<(usize, Array1<f64>)> = (0..self.n_rows)
            .into_par_iter()
            .chunks(CHUNK)
            .map(|idxs| {
                idxs.into_iter()
                    .map(|row_idx| {
                        self.streaming_row_back_substitute(row_idx, ridge_t, delta_beta)
                            .map(|seg| (self.row_offsets[row_idx], seg))
                    })
                    .collect::<Result<Vec<_>, _>>()
            })
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .flatten()
            .collect();
        for (base, seg) in &segments {
            for (c, &v) in seg.iter().enumerate() {
                delta_t[base + c] = v;
            }
        }
    } else {
        // Same per-row computation as the parallel branch, so both paths
        // produce identical results.
        for row_idx in 0..self.n_rows {
            let seg = self.streaming_row_back_substitute(row_idx, ridge_t, delta_beta)?;
            let base = self.row_offsets[row_idx];
            for (c, &v) in seg.iter().enumerate() {
                delta_t[base + c] = v;
            }
        }
    }
    Ok(delta_t)
}
/// Back-substitute a single row and return its latent step `-Δt_i`
/// (length `d_i`, the row's own dimension).
fn streaming_row_back_substitute(
    &self,
    row_idx: usize,
    ridge_t: f64,
    delta_beta: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, ArrowSchurError> {
    let backend = CpuBatchedBlockSolver;
    let row = (self.row_builder)(row_idx)?;
    let di = row.htt.nrows();
    self.validate_row(row_idx, &row)?;
    let factor = self.factor_row(&row, ridge_t, di, row_idx)?;
    let mut htbeta_delta = Array1::<f64>::zeros(di);
    if let Some(op) = self.htbeta_matvec.as_ref() {
        op(row_idx, delta_beta, &mut htbeta_delta);
    } else {
        for c in 0..di {
            let mut acc = 0.0_f64;
            for a in 0..self.k {
                acc += row.htbeta[[c, a]] * delta_beta[a];
            }
            htbeta_delta[c] = acc;
        }
    }
    // The right-hand side is sized to this row's dimension d_i, not the global
    // `d`: the block solve must see a vector that matches the d_i x d_i factor
    // (a longer, reused buffer would carry stale entries from earlier rows).
    let mut rhs = Array1::<f64>::zeros(di);
    for c in 0..di {
        rhs[c] = row.gt[c] + htbeta_delta[c];
    }
    let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
    let mut neg = Array1::<f64>::zeros(di);
    for c in 0..di {
        neg[c] = -dt_i[c];
    }
    Ok(neg)
}
/// Check that a streamed row block has the shapes this system expects.
///
/// The expected row dimension `d_i` is `row_dims[row_idx]`, or `d` when
/// `row_idx` lies beyond `row_dims`. The following are checked:
/// * `H_tt` is `d_i x d_i`;
/// * when no `htbeta_matvec` operator is installed, the dense `H_tβ` is `(d_i, k)`;
/// * `g_t` has length `d_i`.
///
/// # Errors
/// `PerRowFactorFailed` for an `H_tt` or `g_t` mismatch, and
/// `SchurFactorFailed` for a dense `H_tβ` mismatch.
pub(crate) fn validate_row(
    &self,
    row_idx: usize,
    row: &ArrowRowBlock,
) -> Result<(), ArrowSchurError> {
    let expected_di = match self.row_dims.get(row_idx) {
        Some(&dim) => dim,
        None => self.d,
    };
    let htt_ok = row.htt.nrows() == expected_di && row.htt.ncols() == expected_di;
    if !htt_ok {
        return Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!(
                "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
                row.htt.dim(),
            ),
        });
    }
    // The dense H_tβ block is only consulted when no operator replaces it.
    let dense_htbeta_required = self.htbeta_matvec.is_none();
    if dense_htbeta_required && row.htbeta.dim() != (expected_di, self.k) {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: format!(
                "streaming row H_tβ shape {:?} != ({expected_di}, {})",
                row.htbeta.dim(),
                self.k
            ),
        });
    }
    if row.gt.len() == expected_di {
        Ok(())
    } else {
        Err(ArrowSchurError::PerRowFactorFailed {
            row: row_idx,
            reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
        })
    }
}
}
/// Scatter an analytic penalty's gradient and PSD curvature into `scatter_target`.
///
/// 1. `penalty.grad_target(target, rho_local)` is scattered entry by entry
///    through `grad_scatter`.
/// 2. If the penalty provides a diagonal PSD majorizer, its entries are
///    scattered through `diag_scatter` and the function returns.
/// 3. Otherwise, for each of `hvp_columns` columns:
///    * a probe vector is zeroed and then filled by `seed_hvp_probe`; the probe
///      has length `expected_target_len`;
///    * the probe is pushed through `psd_majorizer_hvp`;
///    * the product is handed to `hvp_column_scatter`.
///
/// # Panics
/// If `target.len() != expected_target_len`, or if the diagonal majorizer's
/// length differs from `expected_target_len`.
pub(crate) fn apply_analytic_penalty<S, G, D, P, H>(
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    expected_target_len: usize,
    hvp_columns: usize,
    scatter_target: &mut S,
    mut grad_scatter: G,
    mut diag_scatter: D,
    seed_hvp_probe: P,
    mut hvp_column_scatter: H,
) where
    G: FnMut(&mut S, usize, f64),
    D: FnMut(&mut S, usize, f64),
    P: Fn(usize, &mut Array1<f64>),
    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
{
    assert_eq!(target.len(), expected_target_len);
    let gradient = penalty.grad_target(target, rho_local);
    for idx in 0..expected_target_len {
        grad_scatter(scatter_target, idx, gradient[idx]);
    }
    match penalty.psd_majorizer_diag(target, rho_local) {
        Some(majorizer_diag) => {
            assert_eq!(majorizer_diag.len(), expected_target_len);
            for idx in 0..expected_target_len {
                diag_scatter(scatter_target, idx, majorizer_diag[idx]);
            }
        }
        None => {
            // No diagonal form: recover the curvature column by column via HVPs.
            let mut probe = Array1::<f64>::zeros(expected_target_len);
            for column in 0..hvp_columns {
                probe.fill(0.0);
                seed_hvp_probe(column, &mut probe);
                let product = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
                hvp_column_scatter(scatter_target, column, product.view());
            }
        }
    }
}
pub(crate) fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
penalty.is_row_block_diagonal()
}
/// Contiguous storage for a sequence of square per-row factor matrices.
///
/// Block `i` occupies `data[offsets[i]..offsets[i + 1]]` as a row-major
/// `dims[i] x dims[i]` matrix (this is how `from_blocks` packs it and how
/// `factor` views it). All buffers are `Arc`-shared, so `clone` does not copy
/// the entries.
#[derive(Clone)]
pub struct ArrowFactorSlab {
/// Concatenated row-major entries of every block.
pub(crate) data: Arc<[f64]>,
/// Prefix offsets into `data`: `offsets[0] == 0` and there is one more
/// entry than there are blocks.
pub(crate) offsets: Arc<[usize]>,
/// Side length of each square block.
pub(crate) dims: Arc<[usize]>,
}
impl ArrowFactorSlab {
    /// Pack square blocks into one contiguous slab.
    ///
    /// Each block is copied in logical (row-major) order regardless of its
    /// in-memory layout. This matches the row-major view rebuilt by
    /// [`Self::factor`].
    ///
    /// # Panics
    /// If any block is not square.
    pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
        // Size the value buffer once instead of growing it block by block.
        let total_values: usize = blocks.iter().map(|block| block.len()).sum();
        let mut data = Vec::with_capacity(total_values);
        let mut offsets = Vec::with_capacity(blocks.len() + 1);
        let mut dims = Vec::with_capacity(blocks.len());
        offsets.push(0);
        for block in blocks {
            let (rows, cols) = block.dim();
            assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
            dims.push(rows);
            data.extend(block.iter().copied());
            offsets.push(data.len());
        }
        Self {
            data: data.into(),
            offsets: offsets.into(),
            dims: dims.into(),
        }
    }
    /// Number of stored blocks.
    pub fn len(&self) -> usize {
        self.dims.len()
    }
    /// `true` when no blocks are stored.
    pub fn is_empty(&self) -> bool {
        self.dims.is_empty()
    }
    /// Borrow block `row` as a `dim x dim` view.
    ///
    /// # Panics
    /// If `row >= len()`.
    pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
        let dim = self.dims[row];
        let range = self.offsets[row]..self.offsets[row + 1];
        ArrayView2::from_shape((dim, dim), &self.data[range])
            .expect("ArrowFactorSlab row offset/dim invariant violated")
    }
    /// Iterate the blocks in order as borrowed views.
    pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
        (0..self.len()).map(|row| self.factor(row))
    }
}
impl std::fmt::Debug for ArrowFactorSlab {
    /// Summarise the slab by its block and value counts instead of dumping entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let block_count = self.len();
        let value_count = self.data.len();
        let mut builder = f.debug_struct("ArrowFactorSlab");
        builder.field("rows", &block_count);
        builder.field("values", &value_count);
        builder.finish()
    }
}
/// Where the undamped per-row `H_tt` factors live.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
/// The undamped factors are the damped ones; readers fall back to
/// `ArrowFactorCache::htt_factors`.
SameAsDamped,
/// A separate slab holding the undamped factors.
Owned(ArrowFactorSlab),
}
impl std::fmt::Debug for ArrowUndampedFactors {
    /// Show the variant, and for `Owned` only the number of stored factors.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::Owned(factors) = self else {
            return f.write_str("SameAsDamped");
        };
        f.debug_tuple("Owned").field(&factors.len()).finish()
    }
}
/// Overwrite `out` with `H_tβ^{(row_idx)} · x` for an `ArrowSchurSystem`.
///
/// `out` is zeroed first. When `sys.htbeta_matvec` is installed, the operator
/// is applied. The dense `row.htbeta` block is then added on top, but only if
/// (a) there is no operator or `htbeta_dense_supplement` is set, and
/// (b) its shape is `(out.len(), sys.k)`.
/// If neither source applies, `out` stays zero.
pub(crate) fn sys_htbeta_apply_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    x: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    out.fill(0.0);
    let matvec = sys.htbeta_matvec.as_ref();
    if let Some(op) = matvec {
        op(row_idx, x, out);
    }
    let include_dense = sys.htbeta_dense_supplement || matvec.is_none();
    if !include_dense || row.htbeta.dim() != (out.len(), sys.k) {
        return;
    }
    for c in 0..row.htbeta.nrows() {
        out[c] += (0..sys.k).fold(0.0_f64, |acc, a| acc + row.htbeta[[c, a]] * x[a]);
    }
}
/// Accumulate `H_tβ^{(row_idx)ᵀ} · v` into `out` (length `sys.k`) for an
/// `ArrowSchurSystem`; `out` is NOT cleared.
///
/// When `sys.htbeta_matvec` is installed, its transpose action is recovered
/// by unit-vector probing (`htbeta_probe_transpose`). The dense `row.htbeta`
/// contribution is also added, but only if (a) there is no operator or
/// `htbeta_dense_supplement` is set, and (b) its shape is `(v.len(), sys.k)`.
pub(crate) fn sys_htbeta_accumulate_transpose(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    let matvec = sys.htbeta_matvec.as_ref();
    if let Some(op) = matvec {
        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
    }
    let include_dense = sys.htbeta_dense_supplement || matvec.is_none();
    if !include_dense || row.htbeta.dim() != (v.len(), sys.k) {
        return;
    }
    for (c, &vc) in v.iter().enumerate() {
        // Zero weights contribute nothing; skip the whole row of H_tβ.
        if vc == 0.0 {
            continue;
        }
        for a in 0..sys.k {
            out[a] += row.htbeta[[c, a]] * vc;
        }
    }
}
/// Materialise the dense `(d_i, k)` block `H_tβ^{(row_idx)}` of an `ArrowSchurSystem`.
///
/// When no `htbeta_matvec` is installed, the dense `row.htbeta` is returned
/// as-is and must already have shape `(d_i, k)`. When an operator is
/// installed, each column `a` is recovered by applying it to the unit vector
/// `e_a`. These columns are added on top of the dense block when
/// `htbeta_dense_supplement` is set and the dense shape matches; otherwise
/// they are added to zeros.
///
/// # Errors
/// `SchurFactorFailed` when no operator is installed and the dense block has
/// the wrong shape.
///
/// # Panics
/// If `row_idx` is out of range for `sys.row_dims`.
pub(crate) fn sys_htbeta_materialize_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
) -> Result<Array2<f64>, ArrowSchurError> {
    let di = sys.row_dims[row_idx];
    let k = sys.k;
    let dense_shape_ok = row.htbeta.dim() == (di, k);
    let Some(op) = sys.htbeta_matvec.as_ref() else {
        // Without an operator the dense block is the only source of H_tβ.
        if !dense_shape_ok {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
                    row.htbeta.dim()
                ),
            });
        }
        return Ok(row.htbeta.clone());
    };
    let mut mat = if sys.htbeta_dense_supplement && dense_shape_ok {
        row.htbeta.clone()
    } else {
        Array2::<f64>::zeros((di, k))
    };
    let mut unit = Array1::<f64>::zeros(k);
    let mut column = Array1::<f64>::zeros(di);
    for a in 0..k {
        unit.fill(0.0);
        unit[a] = 1.0;
        column.fill(0.0);
        op(row_idx, unit.view(), &mut column);
        for c in 0..di {
            mat[[c, a]] += column[c];
        }
    }
    Ok(mat)
}
/// Accumulate `opᵀ · v` into `out` using only the forward action of `op`.
///
/// For each `a in 0..k`, the operator is applied to the unit vector `e_a`
/// (into a zeroed length-`d` column). Then `<column, v>` over the first `d`
/// entries is added to `out[a]`. This costs `k` operator applications.
pub(crate) fn htbeta_probe_transpose(
    row: usize,
    op: &RowHtbetaMatvec,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
    d: usize,
    k: usize,
) {
    let mut unit = Array1::<f64>::zeros(k);
    let mut column = Array1::<f64>::zeros(d);
    for a in 0..k {
        unit.fill(0.0);
        unit[a] = 1.0;
        column.fill(0.0);
        op(row, unit.view(), &mut column);
        out[a] += (0..d).fold(0.0_f64, |acc, c| acc + column[c] * v[c]);
    }
}
/// How a factor cache applies the `H_tβ` coupling between latent rows and the
/// shared β block.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
/// Explicit per-row dense blocks, indexed by row.
Dense {
blocks: Arc<[Array2<f64>]>,
/// NOTE(review): presumably an estimate of the cached memory footprint — confirm.
estimated_bytes: usize,
},
/// A row-wise matvec operator applied on demand.
Matvec {
op: RowHtbetaMatvec,
/// NOTE(review): presumably an estimate of the memory footprint — confirm.
estimated_bytes: usize,
},
/// No coupling available: `apply_row` reports failure, and the transpose
/// apply works only with a caller-supplied fallback operator.
Disabled {
/// NOTE(review): presumably the estimate that led to disabling the cache — confirm.
estimated_bytes: usize,
},
}
impl std::fmt::Debug for ArrowHtbetaCache {
    /// Print the variant name, the block count (Dense only) and `estimated_bytes`,
    /// never the block contents or the operator.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, block_count, bytes) = match self {
            Self::Dense {
                blocks,
                estimated_bytes,
            } => ("Dense", Some(blocks.len()), estimated_bytes),
            Self::Matvec {
                estimated_bytes, ..
            } => ("Matvec", None, estimated_bytes),
            Self::Disabled { estimated_bytes } => ("Disabled", None, estimated_bytes),
        };
        let mut builder = f.debug_struct(variant);
        if let Some(count) = block_count {
            builder.field("blocks", &count);
        }
        builder.field("estimated_bytes", bytes).finish()
    }
}
impl ArrowHtbetaCache {
    /// `true` unless the cache is `Disabled`.
    pub(crate) fn is_available(&self) -> bool {
        !matches!(self, Self::Disabled { .. })
    }
    /// Overwrite `out` with `H_tβ^{(row)} · delta_beta`.
    ///
    /// Returns `false`, leaving `out` untouched, when the cache is `Disabled`.
    /// It also returns `false` when a `Dense` cache has no block for `row`, or
    /// that block's shape is not `(out.len(), delta_beta.len())`.
    pub(crate) fn apply_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let mut acc = 0.0_f64;
                    for a in 0..block.ncols() {
                        acc += block[[c, a]] * delta_beta[a];
                    }
                    out[c] = acc;
                }
                true
            }
            Self::Matvec { op, .. } => {
                // The Dense branch overwrites `out`. Every other call site of a
                // `RowHtbetaMatvec` hands it a zeroed buffer (the operator may
                // accumulate), so zero here too: both variants then return the
                // product itself, not `out + product`.
                out.fill(0.0);
                op(row, delta_beta, out);
                true
            }
            Self::Disabled { .. } => false,
        }
    }
    /// Accumulate `H_tβ^{(row)ᵀ} · v` into `out`; `out` is NOT cleared.
    ///
    /// * `Dense`: returns `false` if there is no block for `row` or its shape
    ///   is not `(v.len(), out.len())`.
    /// * `Matvec`: the transpose is recovered by probing the operator
    ///   (`htbeta_probe_transpose`, using `d`/`k`).
    /// * `Disabled`: the same probing is done with `fallback_op` when one is
    ///   given; otherwise returns `false`.
    pub(crate) fn apply_row_transpose_accumulate(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        d: usize,
        k: usize,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.nrows() != v.len() || block.ncols() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let vc = v[c];
                    if vc == 0.0 {
                        continue;
                    }
                    for a in 0..block.ncols() {
                        out[a] += block[[c, a]] * vc;
                    }
                }
                true
            }
            Self::Matvec { op, .. } => {
                htbeta_probe_transpose(row, op, v, out, d, k);
                true
            }
            Self::Disabled { .. } => {
                if let Some(op) = fallback_op {
                    htbeta_probe_transpose(row, op, v, out, d, k);
                    true
                } else {
                    false
                }
            }
        }
    }
}
/// Per-row spectral data kept for deflation.
/// NOTE(review): not read in this chunk; the field meanings below are
/// inferred from their names — confirm against the producer.
#[derive(Debug, Clone)]
pub struct RowDeflationSpectrum {
/// Presumably the eigenvectors of the row block (column layout unconfirmed).
pub evecs: Array2<f64>,
/// Presumably the eigenvalues before any conditioning/clamping.
pub raw_evals: Array1<f64>,
/// Presumably the eigenvalues after conditioning.
pub cond_evals: Array1<f64>,
}
/// Factorisation state of an arrow-structured Hessian.
///
/// Reused by the log-det, inverse-apply and inverse-diagonal queries
/// implemented on this type.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
/// Per-row (damped) `H_tt` factors. `arrow_log_det` and the pivot helpers
/// read them as lower-Cholesky factors.
pub htt_factors: ArrowFactorSlab,
/// Undamped per-row factors, or `SameAsDamped` to reuse `htt_factors`.
pub htt_factors_undamped: ArrowUndampedFactors,
/// Lower-Cholesky factor of the reduced Schur complement on β. It is `None`
/// when no dense factor was formed; per the error messages, this happens in
/// InexactPCG mode.
pub schur_factor: Option<Array2<f64>>,
/// NOTE(review): not read in this chunk; presumably a precomputed joint log-det — confirm.
pub joint_hessian_log_det: Option<f64>,
/// Presumably the solver mode this factorisation was produced with.
pub solver_mode: ArrowSolverMode,
/// Ridge on the row blocks; must be 0 for `compute_undamped_arrow_log_det`.
pub ridge_t: f64,
/// Ridge on the β block; must be 0 for `compute_undamped_arrow_log_det`.
pub ridge_beta: f64,
/// How the `H_tβ` coupling is applied.
pub htbeta: ArrowHtbetaCache,
/// Fallback row dimension, used when a row index lies beyond `row_dims`.
pub d: usize,
/// Latent dimension of each row.
pub row_dims: Arc<[usize]>,
/// Start of each row's segment in the concatenated `Δt`. Entry `n_rows()`
/// is the total length (see `delta_t_len`).
pub row_offsets: Arc<[usize]>,
/// Dimension `K` of the shared β block.
pub k: usize,
/// NOTE(review): fingerprint of the manifold mode at build time — confirm semantics.
pub manifold_mode_fingerprint: u64,
/// NOTE(review): fingerprint of the row Hessians at build time — confirm semantics.
pub row_hessian_fingerprint: u64,
/// PCG diagnostics; its `used_device_arrow` flag backs `used_device()`.
pub pcg_diagnostics: PcgDiagnostics,
/// NOTE(review): not read in this chunk; presumably a count of deflated gauge directions.
pub gauge_deflated_directions: usize,
/// NOTE(review): not read in this chunk; presumably the deflated directions per row.
pub deflated_row_directions: Arc<[Vec<Array1<f64>>]>,
/// NOTE(review): not read in this chunk; presumably the per-row spectra used for deflation.
pub deflation_row_spectra: Arc<[Option<RowDeflationSpectrum>]>,
/// Optional low-rank cross-row correction. When present, it is applied on
/// top of the arrow inverse and folded into the log-det.
pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
/// Low-rank cross-row correction for `H = H0 + U D Uᵀ`.
///
/// `H0` is the arrow Hessian held by an `ArrowFactorCache`. `U` has `r`
/// columns over the latent coordinates, and `D` is diagonal.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
/// Dense `U` (latent length x `r`).
pub u: Array2<f64>,
/// Diagonal of `D` (length `r`).
pub d: Array1<f64>,
/// Latent part of `H0⁻¹ U`.
pub h0inv_u: Array2<f64>,
/// β part of `H0⁻¹ U` (`K x r`).
pub h0inv_u_beta: Array2<f64>,
/// Symmetrised `M = Uᵀ H0⁻¹ U` (`r x r`).
pub m: Array2<f64>,
/// LU factors of the capacitance `C = I + D M`.
pub capacitance_lu: SmallLu,
/// Sparse `(latent index g, column k, value z)` entries of `U`.
pub entries: Vec<(usize, usize, f64)>,
}
/// Dense partial-pivoting LU factorisation `P A = L U` of a small square matrix
/// (produced by `small_lu_factor`).
#[derive(Debug, Clone)]
pub struct SmallLu {
/// `L` (unit diagonal, stored strictly below the diagonal) and `U` (on and
/// above the diagonal) packed together.
pub(crate) lu: Array2<f64>,
/// `piv[i]` is the original row index now at position `i`.
pub(crate) piv: Vec<usize>,
/// Sign (±1) of the row permutation's determinant.
pub(crate) perm_sign: f64,
}
/// LU-factorise a small square matrix with partial (row) pivoting, `P A = L U`.
///
/// Returns `None` when a pivot, or a final `U` diagonal entry, is non-finite
/// or smaller in magnitude than `f64::MIN_POSITIVE` (numerically singular).
///
/// # Panics
/// If `a` is not square.
pub(crate) fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
    let n = a.nrows();
    assert_eq!(a.ncols(), n, "small_lu_factor: non-square input");
    let mut lu = a.clone();
    let mut piv: Vec<usize> = (0..n).collect();
    let mut perm_sign = 1.0_f64;
    for k in 0..n {
        // Pick the first row (at or below k) with the strictly largest |entry|.
        let (best_row, best_mag) =
            ((k + 1)..n).fold((k, lu[[k, k]].abs()), |(row_so_far, mag_so_far), i| {
                let mag = lu[[i, k]].abs();
                if mag > mag_so_far {
                    (i, mag)
                } else {
                    (row_so_far, mag_so_far)
                }
            });
        if !best_mag.is_finite() || best_mag < f64::MIN_POSITIVE {
            return None;
        }
        if best_row != k {
            for c in 0..n {
                lu.swap((k, c), (best_row, c));
            }
            piv.swap(k, best_row);
            perm_sign = -perm_sign;
        }
        let pivot = lu[[k, k]];
        for i in (k + 1)..n {
            let multiplier = lu[[i, k]] / pivot;
            lu[[i, k]] = multiplier;
            for c in (k + 1)..n {
                let upper = lu[[k, c]];
                lu[[i, c]] -= multiplier * upper;
            }
        }
    }
    let diagonal_ok = (0..n).all(|i| {
        let u = lu[[i, i]];
        u.is_finite() && u.abs() >= f64::MIN_POSITIVE
    });
    if !diagonal_ok {
        return None;
    }
    Some(SmallLu { lu, piv, perm_sign })
}
impl SmallLu {
    /// Order `n` of the factored `n x n` matrix.
    pub(crate) fn dim(&self) -> usize {
        self.lu.nrows()
    }
    /// `(ln|det A|, sign(det A))` of the factored matrix `A`.
    ///
    /// `|det A|` is the product of `|U[i, i]|`. The sign starts from the
    /// permutation sign and flips once per negative `U[i, i]`.
    pub(crate) fn log_abs_det_and_sign(&self) -> (f64, f64) {
        (0..self.dim()).fold((0.0_f64, self.perm_sign), |(log_abs, sign), i| {
            let u = self.lu[[i, i]];
            let next_sign = if u < 0.0 { -sign } else { sign };
            (log_abs + u.abs().ln(), next_sign)
        })
    }
    /// Solve `A x = b` with the stored factorisation.
    ///
    /// Returns `None` if a `U` pivot is non-finite or below `f64::MIN_POSITIVE`
    /// in magnitude, or if the solution contains a non-finite entry.
    ///
    /// # Panics
    /// If `b` is shorter than `dim()`.
    pub(crate) fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
        let n = self.dim();
        // y = P b
        let mut y: Array1<f64> = (0..n).map(|i| b[self.piv[i]]).collect();
        // Forward substitution with the unit-diagonal L.
        for i in 0..n {
            let reduced = (0..i).fold(y[i], |acc, j| acc - self.lu[[i, j]] * y[j]);
            y[i] = reduced;
        }
        // Back substitution with U.
        let mut x = Array1::<f64>::zeros(n);
        for i in (0..n).rev() {
            let pivot = self.lu[[i, i]];
            if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
                return None;
            }
            let reduced = ((i + 1)..n).fold(y[i], |acc, j| acc - self.lu[[i, j]] * x[j]);
            x[i] = reduced / pivot;
        }
        let all_finite = x.iter().all(|v| v.is_finite());
        if all_finite {
            Some(x)
        } else {
            None
        }
    }
}
/// Log-det correction `ln det(I + diag(d) M)` for a streamed cross-row Woodbury update.
///
/// `M = m0 + Wᵀ S⁻¹ W`, where `S = schur` is Cholesky-factored with
/// `cholesky_lower`. `M` is symmetrised before the capacitance
/// `C = I + diag(d) M` is LU-factored.
///
/// Returns `Ok(Some(ln|det C|))` when `det C > 0`. Returns `Ok(None)` when `C`
/// is numerically singular or its determinant is non-positive.
///
/// # Errors
/// `SchurFactorFailed` when the Cholesky factorisation of `schur` fails.
pub fn streaming_cross_row_woodbury_log_det(
    schur: &Array2<f64>,
    m0: &Array2<f64>,
    w: &Array2<f64>,
    d: &Array1<f64>,
) -> Result<Option<f64>, ArrowSchurError> {
    let rank = d.len();
    let factor =
        cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
    let mut m = m0.clone();
    for a in 0..rank {
        let solved = cholesky_solve_vector(&factor, &w.column(a).to_owned());
        for b in 0..rank {
            m[[a, b]] += solved.dot(&w.column(b));
        }
    }
    // Average away round-off asymmetry in M.
    for a in 0..rank {
        for b in (a + 1)..rank {
            let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
            m[[a, b]] = avg;
            m[[b, a]] = avg;
        }
    }
    let capacitance = Array2::from_shape_fn((rank, rank), |(a, b)| {
        let scaled = d[a] * m[[a, b]];
        if a == b {
            scaled + 1.0
        } else {
            scaled
        }
    });
    Ok(small_lu_factor(&capacitance).and_then(|lu| {
        let (log_abs, sign) = lu.log_abs_det_and_sign();
        (sign > 0.0).then_some(log_abs)
    }))
}
impl CrossRowWoodbury {
    /// Build the Woodbury correction for `H = H0 + U D Uᵀ`.
    ///
    /// `H0` is the arrow Hessian factored in `cache`; `U` and `D` come from
    /// `source`. The build:
    /// * computes `H0⁻¹ U`, split into latent and β parts;
    /// * forms `M = Uᵀ H0⁻¹ U` and symmetrises it;
    /// * LU-factors the capacitance `C = I + D M`.
    ///
    /// Returns `Ok(None)` when `C` is numerically singular.
    ///
    /// # Errors
    /// Propagates failures of the base arrow inverse apply.
    pub(crate) fn build(
        cache: &ArrowFactorCache,
        source: &IbpCrossRowSource,
    ) -> Result<Option<Self>, ArrowSchurError> {
        let r = source.r;
        let total_len = cache.delta_t_len();
        let u = source.dense_u(total_len);
        let d = source.d.clone();
        let zero_beta = Array1::<f64>::zeros(cache.k);
        let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
        let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
        for col_idx in 0..r {
            let col = u.column(col_idx).to_owned();
            // Use the *base* arrow inverse. `full_inverse_apply` would also
            // fold in any cross-row correction already attached to `cache`,
            // giving `(H0 + U D Uᵀ)⁻¹ U` instead of the `H0⁻¹ U` required here.
            let (sol_t, sol_beta) = cache.full_inverse_apply_base(col.view(), zero_beta.view())?;
            h0inv_u.column_mut(col_idx).assign(&sol_t);
            h0inv_u_beta.column_mut(col_idx).assign(&sol_beta);
        }
        // M = Uᵀ H0⁻¹ U in a single pass over U's sparse entries. This is
        // O(nnz · r) instead of rescanning every entry for each (a, b) pair;
        // each M[a, b] still sums its terms in entry order.
        let mut m = Array2::<f64>::zeros((r, r));
        for &(g, a, z) in &source.entries {
            if a >= r {
                // Entries outside the r columns never contributed to M.
                continue;
            }
            for b in 0..r {
                m[[a, b]] += z * h0inv_u[[g, b]];
            }
        }
        // Average away round-off asymmetry in M.
        for a in 0..r {
            for b in (a + 1)..r {
                let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
                m[[a, b]] = avg;
                m[[b, a]] = avg;
            }
        }
        let mut c = Array2::<f64>::zeros((r, r));
        for a in 0..r {
            for b in 0..r {
                c[[a, b]] = d[a] * m[[a, b]];
            }
            c[[a, a]] += 1.0;
        }
        let Some(capacitance_lu) = small_lu_factor(&c) else {
            return Ok(None);
        };
        Ok(Some(Self {
            u,
            d,
            h0inv_u,
            h0inv_u_beta,
            m,
            capacitance_lu,
            entries: source.entries.clone(),
        }))
    }
    /// Sparse `(latent index g, column k, value z)` entries of `U` captured at build time.
    pub(crate) fn source_entries(&self) -> &[(usize, usize, f64)] {
        &self.entries
    }
    /// Dense `C⁻¹ D` (`r x r`), built with one capacitance solve per column.
    ///
    /// Returns `None` if any solve fails (non-finite result).
    pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
        let r = self.d.len();
        let mut out = Array2::<f64>::zeros((r, r));
        let mut e_l = Array1::<f64>::zeros(r);
        for l in 0..r {
            e_l.fill(0.0);
            e_l[l] = 1.0;
            let col = self.capacitance_lu.solve(&e_l)?;
            for row in 0..r {
                out[[row, l]] = col[row] * self.d[l];
            }
        }
        Some(out)
    }
    /// Subtract the Woodbury term from a latent inverse diagonal, in place.
    ///
    /// For each latent index `g`, subtracts
    /// `Σ_{k,l} (H0⁻¹U)[g,k] (C⁻¹D)[k,l] (H0⁻¹U)[g,l]`.
    ///
    /// # Errors
    /// `SchurFactorFailed` if the capacitance solve for `C⁻¹ D` fails.
    pub(crate) fn subtract_inverse_diagonal(
        &self,
        diag: &mut Array1<f64>,
    ) -> Result<(), ArrowSchurError> {
        let r = self.d.len();
        let cinv_d =
            self.capacitance_inv_times_d()
                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
                             C⁻¹D for the inverse-diagonal correction (#1038): \
                             singular/ill-conditioned cross-row capacitance"
                        .to_string(),
                })?;
        let total_len = self.h0inv_u.nrows();
        for g in 0..total_len {
            let mut acc = 0.0_f64;
            for k in 0..r {
                let gk = self.h0inv_u[[g, k]];
                if gk == 0.0 {
                    continue;
                }
                for l in 0..r {
                    acc += gk * cinv_d[[k, l]] * self.h0inv_u[[g, l]];
                }
            }
            diag[g] -= acc;
        }
        Ok(())
    }
    /// `ln det C` when `det C > 0`, else `None`.
    pub fn log_det(&self) -> Option<f64> {
        let (log_abs, sign) = self.log_det_correction();
        if sign > 0.0 { Some(log_abs) } else { None }
    }
    /// `(ln|det C|, sign(det C))` of the capacitance.
    pub(crate) fn log_det_correction(&self) -> (f64, f64) {
        self.capacitance_lu.log_abs_det_and_sign()
    }
    /// Apply the Woodbury inverse correction in place.
    ///
    /// Given `h0inv_rhs_t`, the latent part of `H0⁻¹ w`, this computes
    /// `q = C⁻¹ D Uᵀ H0⁻¹ w` and subtracts `H0⁻¹ U q` from `(u_t, u_beta)`.
    ///
    /// # Errors
    /// `SchurFactorFailed` if the capacitance solve fails.
    pub(crate) fn apply_inverse_correction(
        &self,
        h0inv_rhs_t: ArrayView1<'_, f64>,
        entries: &[(usize, usize, f64)],
        u_t: &mut Array1<f64>,
        u_beta: &mut Array1<f64>,
    ) -> Result<(), ArrowSchurError> {
        let r = self.d.len();
        let mut p = Array1::<f64>::zeros(r);
        for &(g, k, z) in entries {
            p[k] += z * h0inv_rhs_t[g];
        }
        for k in 0..r {
            p[k] *= self.d[k];
        }
        let q =
            self.capacitance_lu
                .solve(&p)
                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
                    reason: "cross-row Woodbury capacitance solve produced a non-finite \
                             C⁻¹p for the inverse correction (#1038): \
                             singular/ill-conditioned cross-row capacitance"
                        .to_string(),
                })?;
        for g in 0..u_t.len() {
            let mut acc = 0.0_f64;
            for k in 0..r {
                acc += self.h0inv_u[[g, k]] * q[k];
            }
            u_t[g] -= acc;
        }
        for c in 0..u_beta.len() {
            let mut acc = 0.0_f64;
            for k in 0..r {
                acc += self.h0inv_u_beta[[c, k]] * q[k];
            }
            u_beta[c] -= acc;
        }
        Ok(())
    }
    /// Accumulate the forward low-rank term `U D Uᵀ v_t` into `out_t`.
    pub fn apply_forward_t(&self, v_t: ArrayView1<'_, f64>, out_t: &mut Array1<f64>) {
        let r = self.d.len();
        let mut p = Array1::<f64>::zeros(r);
        for &(g, k, z) in &self.entries {
            p[k] += z * v_t[g];
        }
        for k in 0..r {
            p[k] *= self.d[k];
        }
        for &(g, k, z) in &self.entries {
            out_t[g] += z * p[k];
        }
    }
}
/// Smallest squared-diagonal "pivots" of an arrow factorisation.
///
/// `None` in a field means no factor contributed to it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
    /// Minimum over the per-row factors.
    pub min_row_pivot: Option<f64>,
    /// Minimum over the reduced Schur factor.
    pub min_schur_pivot: Option<f64>,
    /// Minimum of the two above, ignoring whichever is absent.
    pub min_pivot: Option<f64>,
}
impl ArrowFactorMinPivot {
    /// Build from the row and Schur minima.
    ///
    /// `min_pivot` is `f64::min` of the two when both are present (so a NaN
    /// operand is ignored), otherwise whichever one is present.
    pub(crate) fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        let min_pivot = match (row, schur) {
            (Some(row_min), Some(schur_min)) => Some(row_min.min(schur_min)),
            (only_row, only_schur) => only_row.or(only_schur),
        };
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot,
        }
    }
}
/// Minimum of `L[i, i]²` over the leading square part of a lower-Cholesky factor.
///
/// Returns `None` when the factor has no rows or no columns. Values are
/// combined with `f64::min`, so NaN squares are ignored unless all of them are NaN.
pub(crate) fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| {
            let diag = factor[[idx, idx]];
            diag * diag
        })
        .reduce(f64::min)
}
/// Maximum of `L[i, i]²` over the leading square part of a lower-Cholesky factor.
///
/// Returns `None` when the factor has no rows or no columns. Values are
/// combined with `f64::max`, so NaN squares are ignored unless all of them are NaN.
pub(crate) fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| {
            let diag = factor[[idx, idx]];
            diag * diag
        })
        .reduce(f64::max)
}
pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
let mut min_row_pivot = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_min_pivot(factor) {
min_row_pivot = Some(match min_row_pivot {
Some(current) => f64::min(current, pivot),
None => pivot,
});
}
}
let min_schur_pivot = cache
.schur_factor
.as_ref()
.and_then(|factor| lower_cholesky_min_pivot(factor.view()));
ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
}
pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
let mut max_pivot: Option<f64> = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_max_pivot(factor) {
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
}
if let Some(factor) = cache.schur_factor.as_ref()
&& let Some(pivot) = lower_cholesky_max_pivot(factor.view())
{
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
max_pivot
}
impl ArrowFactorCache {
    /// Number of latent rows, i.e. the number of damped per-row factors.
    pub fn n_rows(&self) -> usize {
        self.htt_factors.len()
    }
    /// Whether the cached `H_tβ` coupling can be applied (the cache is not `Disabled`).
    pub fn htbeta_available(&self) -> bool {
        self.htbeta.is_available()
    }
    /// Whether the PCG diagnostics report use of the device arrow path.
    #[must_use]
    pub fn used_device(&self) -> bool {
        self.pcg_diagnostics.used_device_arrow
    }
    /// Undamped factor for `row`; this is the damped factor when the cache
    /// records them as identical.
    ///
    /// # Panics
    /// If `row` is out of range.
    pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
            ArrowUndampedFactors::Owned(factors) => factors.factor(row),
        }
    }
    /// Number of undamped factors available.
    pub fn undamped_factor_count(&self) -> usize {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
            ArrowUndampedFactors::Owned(factors) => factors.len(),
        }
    }
    /// Iterate the undamped factors in row order.
    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
    }
    /// Log-determinant of the undamped arrow Hessian.
    ///
    /// The result is `2 Σ ln diag` over the undamped row factors and the Schur
    /// factor, plus the cross-row Woodbury correction.
    ///
    /// Returns `None` when:
    /// * either ridge is non-zero;
    /// * the Schur factor is missing while `K > 0`;
    /// * any factor diagonal entry is non-positive or non-finite;
    /// * the Woodbury correction is non-finite.
    pub fn compute_undamped_arrow_log_det(&self) -> Option<f64> {
        if self.ridge_t != 0.0 || self.ridge_beta != 0.0 {
            return None;
        }
        let schur = match self.schur_factor.as_ref() {
            Some(schur) => Some(schur),
            None if self.k == 0 => None,
            None => return None,
        };
        let mut acc = 0.0_f64;
        for l in self.undamped_factors_iter() {
            for i in 0..l.nrows() {
                let d = l[[i, i]];
                if d <= 0.0 || !d.is_finite() {
                    return None;
                }
                acc += 2.0 * d.ln();
            }
        }
        if let Some(schur) = schur {
            for i in 0..schur.nrows() {
                let d = schur[[i, i]];
                if d <= 0.0 || !d.is_finite() {
                    return None;
                }
                acc += 2.0 * d.ln();
            }
        }
        let woodbury_correction = self.cross_row_woodbury_log_det();
        if !woodbury_correction.is_finite() {
            return None;
        }
        Some(acc + woodbury_correction)
    }
    /// Total length of the concatenated latent vector, `row_offsets[n_rows()]`.
    pub fn delta_t_len(&self) -> usize {
        self.row_offsets[self.n_rows()]
    }
    /// Overwrite `out` with `H_tβ^{(row)} · delta_beta`.
    ///
    /// Returns `false` on a length mismatch (`out` vs the row dimension, or
    /// `delta_beta` vs `K`) or when the cache cannot supply the row.
    pub fn apply_htbeta_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        let di = if row < self.row_dims.len() {
            self.row_dims[row]
        } else {
            self.d
        };
        if out.len() != di || delta_beta.len() != self.k {
            return false;
        }
        self.htbeta.apply_row(row, delta_beta, out)
    }
    /// Accumulate `H_tβ^{(row)ᵀ} · v` into `out`.
    ///
    /// Returns `false` on a length mismatch or when the cache, even with
    /// `fallback_op`, cannot supply the row.
    pub fn apply_htbeta_row_transpose(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        let di = if row < self.row_dims.len() {
            self.row_dims[row]
        } else {
            self.d
        };
        if v.len() != di || out.len() != self.k {
            return false;
        }
        self.htbeta
            .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
    }
    /// Raw log-determinant parts `(row-block term, Schur term)`.
    ///
    /// * Row-block term: `2 Σ ln diag` of the damped row factors.
    /// * Schur term: `2 Σ ln diag` of the Schur factor plus the Woodbury
    ///   correction, or `None` when there is no Schur factor.
    ///
    /// Diagonal entries are not validated, so bad pivots surface as NaN/-inf.
    pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
        let mut log_det_tt = 0.0_f64;
        for l in self.htt_factors.iter() {
            for i in 0..l.nrows() {
                log_det_tt += l[[i, i]].ln();
            }
        }
        log_det_tt *= 2.0;
        let log_det_schur = self.schur_factor.as_ref().map(|l| {
            let mut s = 0.0_f64;
            for i in 0..l.nrows() {
                s += l[[i, i]].ln();
            }
            2.0 * s + self.cross_row_woodbury_log_det()
        });
        (log_det_tt, log_det_schur)
    }
    /// Woodbury log-det correction.
    ///
    /// Returns 0 without a correction, and NaN when the capacitance
    /// determinant is non-positive.
    pub fn cross_row_woodbury_log_det(&self) -> f64 {
        match self.cross_row_woodbury.as_ref() {
            Some(w) => w.log_det().unwrap_or(f64::NAN),
            None => 0.0,
        }
    }
    /// Diagonal of the latent block of the (undamped-row) inverse Hessian.
    ///
    /// For each latent coordinate `j` of row `i`, with `a = H_tt⁻¹ e_j`, the entry is
    /// `a[j] + wᵀ S⁻¹ w`, where `w = H_βt a`. The cross-row Woodbury diagonal
    /// correction is applied when present. With `K = 0` the Schur term vanishes.
    ///
    /// # Errors
    /// When `K > 0`: a missing dense Schur factor, a `Disabled` `H_tβ`
    /// cache, or a row the cache cannot supply. In all cases, a failed
    /// Woodbury correction.
    pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
        // With no shared β block (K = 0) the inverse is block-diagonal over
        // rows: neither a Schur factor nor the H_tβ coupling is needed. This
        // mirrors the `K > 0` guards in `full_inverse_apply_base`.
        let schur_factor = if self.k == 0 {
            None
        } else {
            let Some(schur_factor) = self.schur_factor.as_ref() else {
                return Err(ArrowSchurError::SchurFactorFailed {
                    reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
                             the InexactPCG mode does not form one"
                        .to_string(),
                });
            };
            if !self.htbeta_available() {
                return Err(ArrowSchurError::SchurFactorFailed {
                    reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
                             but this cache's htbeta is Disabled"
                        .to_string(),
                });
            }
            Some(schur_factor)
        };
        let n = self.undamped_factor_count();
        let total_len = self.delta_t_len();
        let mut out = Array1::<f64>::zeros(total_len);
        let mut w = Array1::<f64>::zeros(self.k);
        for i in 0..n {
            let di = self.row_dims[i];
            let row_base = self.row_offsets[i];
            let factor = self.undamped_factor(i);
            for j in 0..di {
                // Unit vector sized to this row's own dimension (never sliced
                // out of a length-`d` scratch buffer, which breaks when d_i > d).
                let mut e_j = Array1::<f64>::zeros(di);
                e_j[j] = 1.0;
                let a = cholesky_solve_vector(factor, &e_j);
                let corr = match schur_factor {
                    None => 0.0,
                    Some(schur_factor) => {
                        w.fill(0.0);
                        if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
                            return Err(ArrowSchurError::SchurFactorFailed {
                                reason: format!(
                                    "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
                                     (htbeta cache could not supply row {i})"
                                ),
                            });
                        }
                        let z = cholesky_solve_vector(schur_factor, &w);
                        let mut corr = 0.0_f64;
                        for c in 0..self.k {
                            corr += w[c] * z[c];
                        }
                        corr
                    }
                };
                out[row_base + j] = a[j] + corr;
            }
        }
        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
            woodbury.subtract_inverse_diagonal(&mut out)?;
        }
        Ok(out)
    }
    /// Apply the full inverse Hessian (including any cross-row Woodbury
    /// correction) to `(w_t, w_beta)`.
    ///
    /// # Errors
    /// See `full_inverse_apply_base`, plus a failed Woodbury correction.
    pub fn full_inverse_apply(
        &self,
        w_t: ArrayView1<'_, f64>,
        w_beta: ArrayView1<'_, f64>,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
        if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
            let h0inv_w_t = u_t.clone();
            woodbury.apply_inverse_correction(
                h0inv_w_t.view(),
                woodbury.source_entries(),
                &mut u_t,
                &mut u_beta,
            )?;
        }
        Ok((u_t, u_beta))
    }
    /// Apply the arrow inverse, *without* the cross-row correction, to `(w_t, w_beta)`.
    ///
    /// This is block elimination with the undamped row factors and the Schur
    /// factor:
    /// * `y = H_tt⁻¹ w_t`;
    /// * `u_β = S⁻¹ (w_β − H_βt y)`;
    /// * `u_t = y − H_tt⁻¹ H_tβ u_β`.
    ///
    /// With `K = 0`, `u_β` is empty and `u_t = y`.
    ///
    /// # Errors
    /// Wrong right-hand-side lengths, a row the `H_tβ` cache cannot supply,
    /// or a missing Schur factor when `K > 0`.
    pub(crate) fn full_inverse_apply_base(
        &self,
        w_t: ArrayView1<'_, f64>,
        w_beta: ArrayView1<'_, f64>,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        let total_len = self.delta_t_len();
        if w_t.len() != total_len || w_beta.len() != self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
                    w_t.len(),
                    w_beta.len(),
                    total_len,
                    self.k
                ),
            });
        }
        let n = self.undamped_factor_count();
        let mut y = Array1::<f64>::zeros(total_len);
        let mut r_beta = w_beta.to_owned();
        for i in 0..n {
            let di = self.row_dims[i];
            let base = self.row_offsets[i];
            let factor = self.undamped_factor(i);
            let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
            let y_row = cholesky_solve_vector(factor, &w_row);
            if self.k > 0 {
                let mut acc = Array1::<f64>::zeros(self.k);
                if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
                    return Err(ArrowSchurError::SchurFactorFailed {
                        reason: format!(
                            "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
                             could not supply row {i}; htbeta={:?}, di={}, k={})",
                            self.htbeta,
                            self.row_dims.get(i).copied().unwrap_or(self.d),
                            self.k
                        ),
                    });
                }
                for c in 0..self.k {
                    r_beta[c] -= acc[c];
                }
            }
            for j in 0..di {
                y[base + j] = y_row[j];
            }
        }
        let u_beta = if self.k > 0 {
            self.schur_inverse_apply(r_beta.view())?
        } else {
            Array1::<f64>::zeros(0)
        };
        let mut u_t = y;
        if self.k > 0 {
            for i in 0..n {
                let di = self.row_dims[i];
                let base = self.row_offsets[i];
                // Fresh zeroed buffer of the row's own length. Slicing a
                // length-`d` scratch buffer panicked for d_i > d and copied it anyway.
                let mut cross = Array1::<f64>::zeros(di);
                if !self.apply_htbeta_row(i, u_beta.view(), &mut cross) {
                    return Err(ArrowSchurError::SchurFactorFailed {
                        reason: format!(
                            "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
                             could not supply row {i})"
                        ),
                    });
                }
                let factor = self.undamped_factor(i);
                let corr = cholesky_solve_vector(factor, &cross);
                for j in 0..di {
                    u_t[base + j] -= corr[j];
                }
            }
        }
        Ok((u_t, u_beta))
    }
    /// Solve `S x = rhs` with the dense Schur factor.
    ///
    /// # Errors
    /// A missing Schur factor, or `rhs.len() != K`.
    pub fn schur_inverse_apply(
        &self,
        rhs: ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, ArrowSchurError> {
        let Some(schur_factor) = self.schur_factor.as_ref() else {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "schur_inverse_apply requires a dense Schur factor; \
                         the InexactPCG mode does not form one"
                    .to_string(),
            });
        };
        if rhs.len() != self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "schur_inverse_apply: rhs length {} != K {}",
                    rhs.len(),
                    self.k
                ),
            });
        }
        let rhs_owned = rhs.to_owned();
        Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
    }
    /// The `block x block` sub-matrix of `S⁻¹`, symmetrised.
    ///
    /// # Errors
    /// A missing Schur factor, or `block.end > K`.
    pub fn schur_inverse_block(
        &self,
        block: std::ops::Range<usize>,
    ) -> Result<Array2<f64>, ArrowSchurError> {
        let Some(schur_factor) = self.schur_factor.as_ref() else {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "schur_inverse_block requires a dense Schur factor; \
                         the InexactPCG mode does not form one"
                    .to_string(),
            });
        };
        if block.end > self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "schur_inverse_block: block end {} exceeds K {}",
                    block.end, self.k
                ),
            });
        }
        let w = block.len();
        let mut out = Array2::<f64>::zeros((w, w));
        let mut e_j = Array1::<f64>::zeros(self.k);
        for (jc, j) in block.clone().enumerate() {
            e_j.fill(0.0);
            e_j[j] = 1.0;
            let col = cholesky_solve_vector(schur_factor, &e_j);
            for (ic, i) in block.clone().enumerate() {
                out[[ic, jc]] = col[i];
            }
        }
        // Average away round-off asymmetry between the solved columns.
        for ic in 0..w {
            for jc in (ic + 1)..w {
                let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
                out[[ic, jc]] = avg;
                out[[jc, ic]] = avg;
            }
        }
        Ok(out)
    }
}