use crate::linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
use crate::matrix::FactorizedSystem;
use crate::solver::estimate::EstimationError;
use crate::solver::rho_optimizer::{
HessianResult, OuterCapability, OuterHessianOperator, OuterObjective, OuterPlan, OuterResult,
};
use faer::Side;
use ndarray::{Array1, Array2};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::sync::Arc;
/// Smallest ρ dimension that `is_frontier_rho_scale` reports as frontier scale
/// (the comparison is inclusive).
pub const PER_ATOM_EFS_MIN_RHO_DIM: usize = 64;
/// Magnitude bound that `sanitize_step` applies to every finite per-axis step.
pub(crate) const PER_ATOM_MAX_STEP: f64 = 5.0;
/// Number of step halvings `backtrack_cost` attempts after trying the full step.
pub(crate) const PER_ATOM_MAX_BACKTRACK: usize = 8;
/// Floor for the convergence margin in `run_per_atom_efs`, and the border-gradient
/// ∞-norm at or below which `solve_shared_border_block` returns a zero step.
pub(crate) const PER_ATOM_NEGLIGIBLE_STEP: f64 = 1e-12;
/// Relative slack, scaled by `max(|current_cost|, 1)`, by which a trial cost may
/// exceed the current cost and still be accepted by `backtrack_cost`.
pub(crate) const PER_ATOM_COST_DESCENT_TOL: f64 = 1e-12;
/// Reports whether `rho_dim` reaches the per-atom EFS size threshold.
///
/// Returns `true` exactly when `rho_dim` is at least
/// `PER_ATOM_EFS_MIN_RHO_DIM`; the threshold itself counts as frontier scale.
#[inline]
pub fn is_frontier_rho_scale(rho_dim: usize) -> bool {
    let below_threshold = rho_dim < PER_ATOM_EFS_MIN_RHO_DIM;
    !below_threshold
}
/// Decides whether the per-atom EFS driver may be used for `cap`.
///
/// All of the following must hold, checked in this order:
/// 1. `cap.all_penalty_like()` is true;
/// 2. a fixed-point update is available and has not been disabled;
/// 3. the ρ dimension reported by `cap.theta_layout()` is at frontier scale
///    (see `is_frontier_rho_scale`).
pub fn per_atom_efs_eligible(cap: &OuterCapability) -> bool {
    if !cap.all_penalty_like() {
        return false;
    }
    let fixed_point_usable = cap.fixed_point_available && !cap.disable_fixed_point;
    if !fixed_point_usable {
        return false;
    }
    is_frontier_rho_scale(cap.theta_layout().rho_dim())
}
/// Outcome of one `run_per_atom_efs` call.
pub struct PerAtomEfsResult {
/// ρ at the point where the iteration stopped (always inside the configured bounds,
/// since every accepted iterate is projected).
pub rho: Array1<f64>,
/// Cost at `rho`: the last accepted line-search cost, or the cost returned by
/// `eval_efs` when the loop stopped before accepting a new point.
pub final_value: f64,
/// Number of outer iterations that were started.
pub iterations: usize,
/// ∞-norm of the last proposed step (before line-search scaling and projection).
pub final_step_inf_norm: f64,
/// `true` only when the step fell below the convergence margin and the cost's
/// log-determinant enclosure gap (if reported) resolved that margin.
pub converged: bool,
}
impl PerAtomEfsResult {
    /// Converts this result into the generic [`OuterResult`] shape.
    ///
    /// `plan_used` is recorded verbatim. The last step ∞-norm is reported in the
    /// `final_grad_norm` slot; gradient, Hessian, and every operator/diagnostic
    /// field are left as `None`.
    pub fn into_outer_result(self, plan_used: OuterPlan) -> OuterResult {
        let Self {
            rho,
            final_value,
            iterations,
            final_step_inf_norm,
            converged,
        } = self;
        OuterResult {
            rho,
            final_value,
            iterations,
            final_grad_norm: Some(final_step_inf_norm),
            final_gradient: None,
            final_hessian: None,
            converged,
            plan_used,
            operator_trust_radius: None,
            operator_stop_reason: None,
            criterion_certificate: None,
            rho_uncertainty_diagnostic: None,
        }
    }
}
/// Identifies which ρ axes form the shared "border" that receives a coupled
/// (block Newton) correction instead of the decoupled per-axis step.
#[derive(Clone, Debug)]
pub struct SharedBorderTopology {
/// Border axis indices. Sorted ascending, deduplicated, and `< rho_dim` when built
/// through `with_border_axes`; empty when built through `disjoint`.
pub(crate) border_axes: Vec<usize>,
/// Total number of ρ axes this topology describes.
pub(crate) rho_dim: usize,
}
impl SharedBorderTopology {
    /// Builds a topology with no border axes for a ρ vector of length `rho_dim`.
    pub fn disjoint(rho_dim: usize) -> Self {
        Self {
            rho_dim,
            border_axes: Vec::new(),
        }
    }

    /// Builds a topology whose border is the given set of axes.
    ///
    /// `axes` is sorted and deduplicated before being stored.
    ///
    /// # Errors
    /// Returns a message naming the largest axis when any axis is `>= rho_dim`.
    pub fn with_border_axes(rho_dim: usize, axes: Vec<usize>) -> Result<Self, String> {
        let mut border_axes = axes;
        border_axes.sort_unstable();
        border_axes.dedup();
        // After sorting, the last entry is the maximum, so checking it alone
        // validates every axis.
        match border_axes.last().copied() {
            Some(last) if last >= rho_dim => Err(format!(
                "SharedBorderTopology: border axis {last} out of range (rho_dim = {rho_dim})"
            )),
            _ => Ok(Self {
                border_axes,
                rho_dim,
            }),
        }
    }

    /// Border axis indices as a slice.
    #[inline]
    pub fn border_axes(&self) -> &[usize] {
        self.border_axes.as_slice()
    }

    /// Number of border axes.
    #[inline]
    pub fn border_count(&self) -> usize {
        self.border_axes().len()
    }

    /// Length of the ρ vector this topology was built for.
    #[inline]
    pub(crate) fn rho_dim(&self) -> usize {
        self.rho_dim
    }
}
/// Settings for `run_per_atom_efs`.
#[derive(Clone, Debug)]
pub struct PerAtomEfsConfig {
/// Step ∞-norm below which the iteration may declare convergence. The driver floors
/// this at `PER_ATOM_NEGLIGIBLE_STEP`.
pub tolerance: f64,
/// Iteration budget. The driver always runs at least one iteration, even when this is 0.
pub max_iter: usize,
/// Per-axis lower bounds on ρ; the driver requires the same length as the seed.
pub lower: Array1<f64>,
/// Per-axis upper bounds on ρ; the driver requires the same length as the seed.
pub upper: Array1<f64>,
}
impl PerAtomEfsConfig {
pub fn new(tolerance: f64, max_iter: usize, lower: Array1<f64>, upper: Array1<f64>) -> Self {
Self {
tolerance,
max_iter,
lower,
upper,
}
}
}
/// Projects a single coordinate onto the interval `[lo, hi]`.
///
/// The lower bound is applied first and the upper bound second, so `hi` wins
/// when the bounds are inverted. A NaN `value` maps to `lo` (then capped by
/// `hi`) because `f64::max` ignores a NaN operand.
#[inline]
pub(crate) fn project_axis(value: f64, lo: f64, hi: f64) -> f64 {
    let raised = f64::max(value, lo);
    f64::min(raised, hi)
}
/// Returns a copy of `rho` with every coordinate projected onto
/// `[cfg.lower[i], cfg.upper[i]]` via `project_axis`.
///
/// # Panics
/// Panics if either bound array is shorter than `rho`.
pub(crate) fn project_to_bounds(rho: &Array1<f64>, cfg: &PerAtomEfsConfig) -> Array1<f64> {
    let mut projected = rho.clone();
    for (axis, coord) in projected.iter_mut().enumerate() {
        *coord = project_axis(*coord, cfg.lower[axis], cfg.upper[axis]);
    }
    projected
}
/// Makes a raw per-axis step safe to apply.
///
/// Non-finite input (NaN or ±∞) becomes `0.0`; finite input is clamped to
/// `[-PER_ATOM_MAX_STEP, PER_ATOM_MAX_STEP]`.
#[inline]
pub(crate) fn sanitize_step(raw: f64) -> f64 {
    if !raw.is_finite() {
        return 0.0;
    }
    raw.clamp(-PER_ATOM_MAX_STEP, PER_ATOM_MAX_STEP)
}
/// Assembles the dense `m × m` outer-Hessian block restricted to the border axes
/// by probing `operator` with one unit vector per border axis.
///
/// Column `j` of the block is the border-axis restriction of
/// `operator.matvec(e_{border[j]})`. The probes run in parallel. The assembled
/// block is symmetrized by averaging each off-diagonal pair.
///
/// `rho` is used only for its length, which fixes the probe dimension.
///
/// # Errors
/// Returns `EstimationError::RemlOptimizationFailed` when
/// * `operator.dim()` differs from `rho.len()`,
/// * a border axis is not a valid index into a vector of length `rho.len()`,
/// * a matvec fails, or
/// * a matvec returns a vector whose length differs from `rho.len()`.
pub(crate) fn border_hessian_block(
    topology: &SharedBorderTopology,
    operator: &Arc<dyn OuterHessianOperator>,
    rho: &Array1<f64>,
) -> Result<Array2<f64>, EstimationError> {
    let n = rho.len();
    // Validate before allocating or probing anything.
    if operator.dim() != n {
        return Err(EstimationError::RemlOptimizationFailed(format!(
            "per-atom border θ-HVP operator dim {} != rho_dim {}",
            operator.dim(),
            rho.len()
        )));
    }
    let border = topology.border_axes();
    let m = border.len();
    // An out-of-range axis would otherwise panic when the unit probe is built.
    if let Some(&bad_axis) = border.iter().find(|&&axis| axis >= n) {
        return Err(EstimationError::RemlOptimizationFailed(format!(
            "per-atom border axis {bad_axis} out of range for rho_dim {n}"
        )));
    }
    let cols: Result<Vec<(usize, Array1<f64>)>, EstimationError> = (0..m)
        .into_par_iter()
        .map(|j| {
            let mut e_j = Array1::<f64>::zeros(n);
            e_j[border[j]] = 1.0;
            let hv = operator.matvec(&e_j).map_err(|reason| {
                EstimationError::RemlOptimizationFailed(format!(
                    "per-atom border θ-HVP operator matvec failed: {reason}"
                ))
            })?;
            // A short result would otherwise panic in the gather loop below.
            if hv.len() != n {
                return Err(EstimationError::RemlOptimizationFailed(format!(
                    "per-atom border θ-HVP operator matvec returned length {} != rho_dim {n}",
                    hv.len()
                )));
            }
            Ok((j, hv))
        })
        .collect();
    let mut block = Array2::<f64>::zeros((m, m));
    for (j, hv) in cols? {
        for (row, &axis) in border.iter().enumerate() {
            block[[row, j]] = hv[axis];
        }
    }
    // Average the two triangles so the downstream symmetric factorization sees
    // an exactly symmetric matrix.
    for r in 0..m {
        for c in (r + 1)..m {
            let s = 0.5 * (block[[r, c]] + block[[c, r]]);
            block[[r, c]] = s;
            block[[c, r]] = s;
        }
    }
    Ok(block)
}
/// Solves the ridge-stabilized border system `(H_bb + ridge·I) δ = g_b` and
/// returns the Newton-type step `-δ`, scattered into a full-length ρ vector.
///
/// * `block` — the `m × m` border Hessian block (`m = topology.border_count()`).
/// * `gradient` — full-length outer gradient; only border entries are read.
///
/// The returned vector has length `topology.rho_dim()`, is zero on non-border
/// axes, and has every border entry passed through `sanitize_step`. A zero
/// vector is returned when there is no border or when the border gradient
/// ∞-norm is at most `PER_ATOM_NEGLIGIBLE_STEP`.
///
/// # Errors
/// Returns `EstimationError::RemlOptimizationFailed` when `block` does not have
/// shape `m × m`, or when the factorization or solve fails.
fn solve_shared_border_block(
    topology: &SharedBorderTopology,
    mut block: Array2<f64>,
    gradient: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
    let border = topology.border_axes();
    let m = topology.border_count();
    let mut step = Array1::<f64>::zeros(topology.rho_dim());
    if m == 0 {
        return Ok(step);
    }

    // Nothing to correct when the border gradient is already negligible.
    let mut g_border_inf = 0.0_f64;
    for &axis in border {
        g_border_inf = g_border_inf.max(gradient[axis].abs());
    }
    if g_border_inf <= PER_ATOM_NEGLIGIBLE_STEP {
        return Ok(step);
    }

    if block.dim() != (m, m) {
        return Err(EstimationError::RemlOptimizationFailed(format!(
            "per-atom shared-border block shape {:?} != expected {m}x{m}",
            block.dim()
        )));
    }

    // Ridge is relative to the mean absolute diagonal, with an absolute
    // fallback when that scale is zero or non-finite.
    let diag_abs_sum = (0..m).fold(0.0_f64, |acc, r| acc + block[[r, r]].abs());
    let diag_scale = diag_abs_sum / (m as f64);
    let ridge = if diag_scale.is_finite() && diag_scale > 0.0 {
        1e-8 * diag_scale
    } else {
        1e-8
    };
    (0..m).for_each(|r| block[[r, r]] += ridge);

    let g_border = Array1::from_vec(border.iter().map(|&axis| gradient[axis]).collect());

    let factor = {
        let view = FaerArrayView::new(&block);
        match factorize_symmetricwith_fallback(view.as_ref(), Side::Lower) {
            Ok(factor) => factor,
            Err(err) => {
                return Err(EstimationError::RemlOptimizationFailed(format!(
                    "per-atom shared-border {m}×{m} factorization failed: {err:?}"
                )));
            }
        }
    };
    let delta = match FactorizedSystem::solve(&factor, &g_border) {
        Ok(delta) => delta,
        Err(reason) => {
            return Err(EstimationError::RemlOptimizationFailed(format!(
                "per-atom shared-border {m}×{m} solve failed: {reason}"
            )));
        }
    };

    // Newton direction is the negated solve result; clamp/clean each entry.
    border
        .iter()
        .enumerate()
        .for_each(|(row, &axis)| step[axis] = sanitize_step(-delta[row]));
    Ok(step)
}
/// Computes the coupled border step from a matrix-free outer Hessian.
///
/// Probes `operator` to build the border Hessian block
/// (`border_hessian_block`) and then solves it against `gradient`
/// (`solve_shared_border_block`). With an empty border the operator is never
/// touched and a zero vector of length `topology.rho_dim()` is returned.
///
/// # Errors
/// Propagates any error from block assembly or from the solve.
pub(crate) fn shared_border_correction(
    topology: &SharedBorderTopology,
    operator: &Arc<dyn OuterHessianOperator>,
    rho: &Array1<f64>,
    gradient: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
    if topology.border_count() == 0 {
        return Ok(Array1::<f64>::zeros(topology.rho_dim()));
    }
    border_hessian_block(topology, operator, rho)
        .and_then(|block| solve_shared_border_block(topology, block, gradient))
}
pub(crate) fn backtrack_cost(
obj: &mut dyn OuterObjective,
rho: &Array1<f64>,
full_step: &Array1<f64>,
current_cost: f64,
cfg: &PerAtomEfsConfig,
) -> Result<Option<(Array1<f64>, f64, f64)>, EstimationError> {
let mut alpha = 1.0_f64;
let descent_slack = PER_ATOM_COST_DESCENT_TOL * current_cost.abs().max(1.0);
for _ in 0..=PER_ATOM_MAX_BACKTRACK {
let mut trial = rho.clone();
for i in 0..trial.len() {
trial[i] += alpha * full_step[i];
}
let trial = project_to_bounds(&trial, cfg);
match obj.eval_cost(&trial) {
Ok(cost) if cost.is_finite() && cost <= current_cost + descent_slack => {
return Ok(Some((trial, cost, alpha)));
}
Ok(_) => {}
Err(_) => {}
}
alpha *= 0.5;
}
Ok(None)
}
/// Runs the per-atom EFS fixed-point iteration with an optional coupled
/// correction on the shared border axes.
///
/// Each iteration:
/// 1. evaluates `obj.eval_efs` at the current ρ and sanitizes the per-axis steps;
/// 2. if the topology has border axes, evaluates `obj.eval` and, when a usable
///    outer Hessian is available, replaces the border entries of the step with
///    the ridge-stabilized block-Newton step. When no usable Hessian is
///    available, the gradient has the wrong length, or the correction fails,
///    the decoupled EFS step is kept on the border axes;
/// 3. declares convergence when the step ∞-norm is below
///    `max(cfg.tolerance, PER_ATOM_NEGLIGIBLE_STEP)` and the reported logdet
///    enclosure gap (if any) resolves that margin;
/// 4. otherwise applies the step through `backtrack_cost`, stopping with
///    `converged = false` if every trial is rejected.
///
/// # Parameters
/// * `obj` — objective providing `eval_efs`, `eval`, and `eval_cost`.
/// * `seed` — starting ρ; it is projected onto the bounds before use.
/// * `cfg` — tolerance, iteration budget (at least one iteration always runs),
///   and box bounds.
/// * `topology` — border axes; must describe a ρ vector of the seed's length.
///
/// # Errors
/// * `EstimationError::InvalidInput` when the bounds or the topology do not
///   match the seed length.
/// * `EstimationError::RemlOptimizationFailed` when `eval_efs` returns a
///   non-finite cost or a step vector of the wrong length.
/// * Any error returned by `obj.eval_efs` or `obj.eval`.
pub fn run_per_atom_efs(
    obj: &mut dyn OuterObjective,
    seed: &Array1<f64>,
    cfg: &PerAtomEfsConfig,
    topology: &SharedBorderTopology,
) -> Result<PerAtomEfsResult, EstimationError> {
    let rho_dim = seed.len();
    if cfg.lower.len() != rho_dim || cfg.upper.len() != rho_dim {
        return Err(EstimationError::InvalidInput(format!(
            "per-atom EFS bounds dim mismatch: lower={}, upper={}, rho={}",
            cfg.lower.len(),
            cfg.upper.len(),
            rho_dim
        )));
    }
    if topology.rho_dim() != rho_dim {
        return Err(EstimationError::InvalidInput(format!(
            "per-atom EFS topology rho_dim {} != seed dim {}",
            topology.rho_dim(),
            rho_dim
        )));
    }
    let mut rho = project_to_bounds(seed, cfg);
    let mut iterations = 0usize;
    let mut final_step_inf = f64::INFINITY;
    let mut last_cost = f64::INFINITY;
    let mut converged = false;
    for _ in 0..cfg.max_iter.max(1) {
        iterations += 1;
        let efs = obj.eval_efs(&rho)?;
        if !efs.cost.is_finite() {
            return Err(EstimationError::RemlOptimizationFailed(
                "per-atom EFS: non-finite cost from eval_efs".to_string(),
            ));
        }
        if efs.steps.len() != rho_dim {
            return Err(EstimationError::RemlOptimizationFailed(format!(
                "per-atom EFS: step length {} != rho_dim {}",
                efs.steps.len(),
                rho_dim
            )));
        }
        last_cost = efs.cost;

        // Decoupled per-axis step. The clamp is a trivial per-element
        // operation, so a plain sequential pass is used.
        let mut full_step: Array1<f64> = {
            let clamped: Vec<f64> = efs.steps.iter().map(|&raw| sanitize_step(raw)).collect();
            Array1::from_vec(clamped)
        };

        if topology.border_count() > 0 {
            let outer_eval = obj.eval(&rho)?;
            let gradient = &outer_eval.gradient;
            if gradient.len() == rho_dim {
                // `None` means "no correction available": the decoupled step
                // computed above must stay in place on the border axes.
                let border_step_result: Option<Result<Array1<f64>, EstimationError>> =
                    match &outer_eval.hessian {
                        HessianResult::Analytic(hessian)
                            if hessian.nrows() == rho_dim && hessian.ncols() == rho_dim =>
                        {
                            let border = topology.border_axes();
                            let m = border.len();
                            let block = Array2::from_shape_fn((m, m), |(r, c)| {
                                hessian[[border[r], border[c]]]
                            });
                            Some(solve_shared_border_block(topology, block, gradient))
                        }
                        HessianResult::Operator(op) => {
                            Some(shared_border_correction(topology, op, &rho, gradient))
                        }
                        _ => {
                            log::debug!(
                                "[PER-ATOM-EFS] no usable outer Hessian; shared-border \
                                 correction deferred to decoupled step for this iter"
                            );
                            None
                        }
                    };
                match border_step_result {
                    Some(Ok(border_step)) => {
                        for &axis in topology.border_axes() {
                            full_step[axis] = border_step[axis];
                        }
                    }
                    Some(Err(err)) => {
                        log::debug!("[PER-ATOM-EFS] shared-border correction skipped: {err}");
                    }
                    None => {}
                }
            }
        }

        let step_inf = full_step.iter().map(|s| s.abs()).fold(0.0_f64, f64::max);
        final_step_inf = step_inf;
        let margin = cfg.tolerance.max(PER_ATOM_NEGLIGIBLE_STEP);
        // A small step is trusted as convergence only when the cost itself is
        // resolved to the same margin (or no enclosure gap is reported).
        let cost_resolved_below_margin = match efs.logdet_enclosure_gap {
            Some(gap) => {
                crate::solver::logdet_bounds::LogdetEnclosure::gap_resolves_margin(gap, margin)
            }
            None => true,
        };
        if step_inf < margin {
            if cost_resolved_below_margin {
                converged = true;
                break;
            }
            log::info!(
                "[PER-ATOM-EFS] step within tolerance {margin:.3e} but cost logdet enclosure gap \
                 {:.3e} exceeds it; refining the bound before declaring convergence",
                efs.logdet_enclosure_gap.unwrap_or(0.0)
            );
        }

        match backtrack_cost(obj, &rho, &full_step, efs.cost, cfg)? {
            Some((rho_new, cost_new, _alpha)) => {
                rho = rho_new;
                last_cost = cost_new;
            }
            None => {
                log::info!(
                    "[PER-ATOM-EFS] step rejected after {} halvings at cost={:.6e} \
                     (rho_dim={}, border={}); reporting stall",
                    PER_ATOM_MAX_BACKTRACK,
                    efs.cost,
                    rho_dim,
                    topology.border_count(),
                );
                break;
            }
        }
    }
    Ok(PerAtomEfsResult {
        rho,
        final_value: last_cost,
        iterations,
        final_step_inf_norm: final_step_inf,
        converged,
    })
}
// Unit tests for the per-atom EFS driver, built around a quadratic objective
// whose minimizer, gradient, and Hessian are known in closed form.
#[cfg(test)]
mod tests {
use super::*;
use crate::solver::rho_optimizer::{
DeclaredHessianForm, Derivative, EfsEval, OuterEval, SeedOutcome,
};
use ndarray::array;
/// Matrix-backed Hessian operator: `matvec(v)` is the dense product `a · v`.
pub(crate) struct QuadraticOperator {
pub(crate) a: Array2<f64>,
}
impl OuterHessianOperator for QuadraticOperator {
fn dim(&self) -> usize {
self.a.nrows()
}
fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
Ok(self.a.dot(v))
}
}
/// Quadratic test objective `0.5 · (ρ − target)ᵀ a (ρ − target)`.
pub(crate) struct QuadraticObjective {
pub(crate) a: Array2<f64>,
pub(crate) target: Array1<f64>,
}
impl QuadraticObjective {
/// Gradient `a · (ρ − target)`.
pub(crate) fn grad(&self, rho: &Array1<f64>) -> Array1<f64> {
self.a.dot(&(rho - &self.target))
}
/// Cost `0.5 · eᵀ a e` with `e = ρ − target`.
pub(crate) fn cost(&self, rho: &Array1<f64>) -> f64 {
let e = rho - &self.target;
0.5 * e.dot(&self.a.dot(&e))
}
}
impl OuterObjective for QuadraticObjective {
fn capability(&self) -> OuterCapability {
OuterCapability {
gradient: Derivative::Analytic,
hessian: DeclaredHessianForm::Dense,
n_params: self.a.nrows(),
psi_dim: 0,
fixed_point_available: true,
barrier_config: None,
prefer_gradient_only: false,
disable_fixed_point: false,
}
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
Ok(self.cost(rho))
}
// Always reports the Hessian as a matrix-free operator, so the driver's
// operator-probing border path is the one exercised by these tests.
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
Ok(OuterEval {
cost: self.cost(rho),
gradient: self.grad(rho),
hessian: HessianResult::Operator(Arc::new(QuadraticOperator { a: self.a.clone() })),
inner_beta_hint: None,
})
}
// Per-axis step `-g_i / a_ii`: exact for a diagonal `a`, approximate otherwise.
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
let g = self.grad(rho);
let steps: Vec<f64> = (0..rho.len()).map(|i| -g[i] / self.a[[i, i]]).collect();
Ok(EfsEval {
cost: self.cost(rho),
steps,
beta: None,
psi_gradient: None,
psi_indices: None,
inner_hessian_scale: None,
logdet_enclosure_gap: None,
})
}
fn reset(&mut self) {}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
if !beta.is_empty() {
assert_eq!(beta.len(), self.a.nrows());
}
Ok(SeedOutcome::NoSlot)
}
}
/// Config with tolerance 1e-9, a 200-iteration budget, and bounds `[-50, 50]` on every axis.
pub(crate) fn wide_bounds(dim: usize) -> PerAtomEfsConfig {
PerAtomEfsConfig::new(
1e-9,
200,
Array1::from_elem(dim, -50.0),
Array1::from_elem(dim, 50.0),
)
}
/// `with_border_axes` sorts and deduplicates its input, rejects an axis equal to
/// `rho_dim`, and accepts an empty axis list.
#[test]
pub(crate) fn with_border_axes_sorts_dedups_and_validates() {
let t = SharedBorderTopology::with_border_axes(8, vec![5, 1, 5, 3]).expect("topology");
assert_eq!(t.border_axes(), &[1, 3, 5]);
assert_eq!(t.border_count(), 3);
let err = SharedBorderTopology::with_border_axes(4, vec![0, 4]).unwrap_err();
assert!(err.contains("out of range"), "got: {err}");
let t = SharedBorderTopology::with_border_axes(4, Vec::new()).expect("empty");
assert_eq!(t.border_count(), 0);
}
/// With a diagonal Hessian and no border axes, the decoupled step alone must reach
/// the target on all 96 axes.
#[test]
pub(crate) fn decoupled_primary_converges_on_separable_objective() {
let dim = 96; let a = Array2::from_shape_fn(
(dim, dim),
|(i, j)| {
if i == j { 1.0 + (i % 5) as f64 } else { 0.0 }
},
);
let target = Array1::from_shape_fn(dim, |i| ((i as f64) * 0.37).sin() * 2.0);
let mut obj = QuadraticObjective {
a,
target: target.clone(),
};
let cfg = wide_bounds(dim);
let topology = SharedBorderTopology::disjoint(dim);
let seed = Array1::zeros(dim);
let result = run_per_atom_efs(&mut obj, &seed, &cfg, &topology).expect("run");
assert!(result.converged, "separable quadratic must converge");
for i in 0..dim {
assert!(
(result.rho[i] - target[i]).abs() < 1e-6,
"coord {i}: {} vs target {}",
result.rho[i],
target[i]
);
}
assert!(result.final_value < 1e-10);
}
/// Axes 0 and 1 are coupled through an off-diagonal Hessian entry and declared as the
/// border; the run must still converge to the target with a vanishing final step.
#[test]
pub(crate) fn border_correction_solves_the_coupled_block() {
let dim = 6;
let mut a = Array2::<f64>::eye(dim) * 2.0;
a[[0, 1]] = 0.4;
a[[1, 0]] = 0.4;
let target = array![1.0, -2.0, 0.5, 0.0, -1.0, 3.0];
let mut obj = QuadraticObjective {
a,
target: target.clone(),
};
let cfg = wide_bounds(dim);
let topology = SharedBorderTopology::with_border_axes(dim, vec![0, 1]).expect("topology");
let seed = Array1::zeros(dim);
let result = run_per_atom_efs(&mut obj, &seed, &cfg, &topology).expect("run");
assert!(result.converged, "coupled quadratic must converge");
for i in 0..dim {
assert!(
(result.rho[i] - target[i]).abs() < 1e-5,
"coord {i}: {} vs target {}",
result.rho[i],
target[i]
);
}
assert!(
result.final_step_inf_norm < 1e-8,
"correction must null at the stationary point (got {})",
result.final_step_inf_norm
);
}
/// Checks a test-local forwarding helper: its output must match the operator's matvec
/// bit for bit, and a vector of the wrong length must be rejected.
#[test]
pub(crate) fn theta_hvp_forwards_exactly_to_the_operator_action() {
// Local helper (not part of the production module): dimension check, then forward.
fn theta_hvp_matrix_free(
operator: &Arc<dyn OuterHessianOperator>,
v: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
if operator.dim() != v.len() {
return Err(EstimationError::RemlOptimizationFailed(format!(
"per-atom θ-HVP operator dim {} != vector len {}",
operator.dim(),
v.len()
)));
}
operator.matvec(v).map_err(|reason| {
EstimationError::RemlOptimizationFailed(format!(
"per-atom θ-HVP operator matvec failed (dim={}): {reason}",
v.len()
))
})
}
let a = array![[2.0, 0.3, 0.0], [0.3, 1.5, -0.2], [0.0, -0.2, 4.0]];
let v = array![0.3, -1.1, 0.9];
let exact = a.dot(&v);
let op: Arc<dyn OuterHessianOperator> = Arc::new(QuadraticOperator { a: a.clone() });
let hv_op = theta_hvp_matrix_free(&op, &v).expect("op hvp");
for i in 0..3 {
assert_eq!(hv_op[i].to_bits(), exact[i].to_bits());
}
let wrong = array![1.0, 2.0];
assert!(theta_hvp_matrix_free(&op, &wrong).is_err());
}
/// Every axis is a border axis, so each iteration's step comes entirely from the
/// border block solve; the run must converge to the target.
#[test]
pub(crate) fn full_border_reduces_to_dense_newton_in_one_correction() {
let dim = 4;
let a = array![
[3.0, 0.2, 0.0, 0.1],
[0.2, 2.0, 0.1, 0.0],
[0.0, 0.1, 1.5, 0.2],
[0.1, 0.0, 0.2, 2.5]
];
let target = array![0.3, -0.6, 1.2, -0.1];
let mut obj = QuadraticObjective {
a,
target: target.clone(),
};
let cfg = wide_bounds(dim);
let topology =
SharedBorderTopology::with_border_axes(dim, (0..dim).collect()).expect("topology");
let seed = Array1::from_elem(dim, 2.0);
let result = run_per_atom_efs(&mut obj, &seed, &cfg, &topology).expect("run");
assert!(result.converged);
for i in 0..dim {
assert!((result.rho[i] - target[i]).abs() < 1e-5);
}
}
/// Wraps `QuadraticObjective` and attaches a fixed logdet enclosure gap to every
/// `eval_efs` result; all other methods delegate unchanged.
pub(crate) struct EnclosureGapObjective {
pub(crate) inner: QuadraticObjective,
pub(crate) gap: f64,
}
impl OuterObjective for EnclosureGapObjective {
fn capability(&self) -> OuterCapability {
self.inner.capability()
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
self.inner.eval_cost(rho)
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
self.inner.eval(rho)
}
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
Ok(self
.inner
.eval_efs(rho)?
.with_logdet_enclosure_gap(Some(self.gap)))
}
fn reset(&mut self) {
self.inner.reset()
}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
self.inner.seed_inner_state(beta)
}
}
/// Same separable problem run twice with tolerance 1e-6: a reported gap of 1e-2 must
/// keep the driver from declaring convergence, while a gap of 1e-9 must not.
#[test]
pub(crate) fn efs_refuses_to_converge_below_the_logdet_enclosure_margin() {
let dim = 96; let a = Array2::from_shape_fn(
(dim, dim),
|(i, j)| if i == j { 1.0 + (i % 5) as f64 } else { 0.0 },
);
let target = Array1::from_shape_fn(dim, |i| ((i as f64) * 0.37).sin() * 2.0);
let mut cfg = wide_bounds(dim);
cfg.tolerance = 1e-6;
let topology = SharedBorderTopology::disjoint(dim);
let seed = Array1::zeros(dim);
let mut wide = EnclosureGapObjective {
inner: QuadraticObjective {
a: a.clone(),
target: target.clone(),
},
gap: 1e-2,
};
let wide_result = run_per_atom_efs(&mut wide, &seed, &cfg, &topology).expect("wide run");
assert!(
!wide_result.converged,
"an enclosure gap wider than the step tolerance must block convergence"
);
let mut tight = EnclosureGapObjective {
inner: QuadraticObjective {
a,
target: target.clone(),
},
gap: 1e-9,
};
let tight_result = run_per_atom_efs(&mut tight, &seed, &cfg, &topology).expect("tight run");
assert!(
tight_result.converged,
"an enclosure gap below the step tolerance must not obstruct convergence"
);
}
}