use super::*;
/// Outer Hessian operator implemented directly on top of the owned dense `matrix`.
impl crate::solver::rho_optimizer::OuterHessianOperator for OwnedDenseOuterHessianOperator {
    /// Operator dimension, reported as the row count of the stored matrix.
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }

    /// Returns the product of the stored matrix with `v` as a freshly allocated vector.
    ///
    /// # Errors
    /// Returns a dimension-mismatch message when `v.len()` differs from the matrix column count.
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
        let expected = self.matrix.ncols();
        if v.len() == expected {
            return Ok(self.matrix.dot(v));
        }
        Err(CustomFamilyError::DimensionMismatch {
            reason: format!(
                "batched dense outer Hessian matvec length mismatch: got {}, expected {}",
                v.len(),
                expected
            ),
        }
        .into())
    }

    /// Writes the product of the stored matrix with `v` into `out`, one row dot product per
    /// output entry.
    ///
    /// # Errors
    /// Returns a dimension-mismatch message when `v.len()` differs from the column count, or
    /// when `out.len()` differs from the row count. The input check takes precedence.
    fn apply_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<(), String> {
        let n_rows = self.matrix.nrows();
        let n_cols = self.matrix.ncols();
        // Validate the input first, then the output, and report only the first failure.
        let mismatch = if v.len() != n_cols {
            Some(format!(
                "batched dense outer Hessian apply_into input length mismatch: got {}, expected {}",
                v.len(),
                n_cols
            ))
        } else if out.len() != n_rows {
            Some(format!(
                "batched dense outer Hessian apply_into output length mismatch: got {}, expected {}",
                out.len(),
                n_rows
            ))
        } else {
            None
        };
        if let Some(reason) = mismatch {
            return Err(CustomFamilyError::DimensionMismatch { reason }.into());
        }
        for (slot, matrix_row) in out.iter_mut().zip(self.matrix.rows()) {
            *slot = matrix_row.dot(v);
        }
        Ok(())
    }

    /// Always `true` for this operator.
    fn is_cheap_to_materialize(&self) -> bool {
        true
    }
}
/// Wraps a `base` outer Hessian operator that works in "physical" coordinates and exposes it
/// in "outer" (labeled) coordinates of length `outer_dim`.
///
/// The trait implementation scatters an outer vector into physical slots through
/// `physical_to_outer`, applies `base`, and sums the physical result back per outer index.
pub(crate) struct LabeledOuterHessianOperator {
/// Operator that is applied to the physical-coordinate vector or factor.
pub(crate) base: Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>,
/// For each physical index, the outer index it reads from and accumulates into;
/// `None` slots are fed zero and their output is dropped.
pub(crate) physical_to_outer: Vec<Option<usize>>,
/// Length of the outer-coordinate vectors; this is what `dim()` reports.
pub(crate) outer_dim: usize,
/// Reusable (physical input, physical output) buffers used by `apply_into`, behind a mutex
/// because the trait method only receives `&self`.
pub(crate) scratch: std::sync::Mutex<(ndarray::Array1<f64>, ndarray::Array1<f64>)>,
}
impl LabeledOuterHessianOperator {
    /// Builds a labeled wrapper around `base` from a penalty label layout.
    ///
    /// The physical-to-outer map is cloned from `layout.physical_to_outer`, the outer
    /// dimension is `layout.initial_rho.len()`, and both scratch buffers are zero vectors with
    /// one entry per physical index.
    pub(crate) fn new(
        base: Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>,
        layout: &PenaltyLabelLayout,
    ) -> Self {
        let physical_len = layout.physical_to_outer.len();
        let scratch_in = ndarray::Array1::<f64>::zeros(physical_len);
        let scratch_out = ndarray::Array1::<f64>::zeros(physical_len);
        Self {
            base,
            physical_to_outer: layout.physical_to_outer.clone(),
            outer_dim: layout.initial_rho.len(),
            scratch: std::sync::Mutex::new((scratch_in, scratch_out)),
        }
    }
}
/// Presents the wrapped physical-coordinate operator in outer (labeled) coordinates.
///
/// Every product follows the same three steps: scatter the outer input into physical slots
/// through `physical_to_outer` (slots mapped to `None` receive zero), apply `base`, then sum
/// each physical output entry onto the outer index its slot maps to.
///
/// NOTE(review): the scatter/gather steps index with the stored outer indices directly, so a
/// `Some(idx)` with `idx >= outer_dim` would panic — presumably the layout guarantees this
/// never happens; confirm against the code that builds `PenaltyLabelLayout`.
impl crate::solver::rho_optimizer::OuterHessianOperator for LabeledOuterHessianOperator {
    /// Dimension of the outer (labeled) coordinate space.
    fn dim(&self) -> usize {
        self.outer_dim
    }

    /// Applies the operator to `v` and returns a newly allocated outer-coordinate result.
    ///
    /// # Errors
    /// Fails when `v.len() != outer_dim`, when `base.matvec` fails, or when the base result
    /// does not have one entry per physical index.
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
        if v.len() != self.outer_dim {
            return Err(format!(
                "labeled outer Hessian input length mismatch: got {}, expected {}",
                v.len(),
                self.outer_dim
            ));
        }
        // Scatter: each physical slot copies the value of its outer index, or zero if unmapped.
        let mut physical = Array1::<f64>::zeros(self.physical_to_outer.len());
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            physical[physical_idx] = outer_idx.map(|idx| v[idx]).unwrap_or(0.0);
        }
        let physical_out = self.base.matvec(&physical)?;
        if physical_out.len() != self.physical_to_outer.len() {
            return Err(format!(
                "labeled outer Hessian physical matvec length mismatch: got {}, expected {}",
                physical_out.len(),
                self.physical_to_outer.len()
            ));
        }
        // Gather: physical slots sharing an outer index are summed together.
        let mut out = Array1::<f64>::zeros(self.outer_dim);
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            if let Some(outer_idx) = *outer_idx {
                out[outer_idx] += physical_out[physical_idx];
            }
        }
        Ok(out)
    }

    /// Same product as `matvec`, written into `out` and using the shared scratch buffers
    /// instead of allocating.
    ///
    /// The scratch mutex stays locked until this call returns, so concurrent calls on the same
    /// operator run one at a time.
    ///
    /// # Errors
    /// Fails when `v` or `out` does not have length `outer_dim`, when the scratch mutex is
    /// poisoned, when `base.apply_into` fails, or when the physical output buffer no longer
    /// has one entry per physical index after the base call.
    fn apply_into(
        &self,
        v: &ndarray::Array1<f64>,
        out: &mut ndarray::Array1<f64>,
    ) -> Result<(), String> {
        if v.len() != self.outer_dim {
            return Err(format!(
                "labeled outer Hessian apply_into input length mismatch: got {}, expected {}",
                v.len(),
                self.outer_dim
            ));
        }
        if out.len() != self.outer_dim {
            return Err(format!(
                "labeled outer Hessian apply_into output length mismatch: got {}, expected {}",
                out.len(),
                self.outer_dim
            ));
        }
        let mut guard = self
            .scratch
            .lock()
            .map_err(|_| "labeled outer Hessian scratch lock poisoned".to_string())?;
        let (physical_in, physical_out) = &mut *guard;
        // Every slot of the input scratch is overwritten, so no stale data survives.
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            physical_in[physical_idx] = outer_idx.map(|idx| v[idx]).unwrap_or(0.0);
        }
        self.base.apply_into(physical_in, physical_out)?;
        if physical_out.len() != self.physical_to_outer.len() {
            return Err(format!(
                "labeled outer Hessian physical apply_into length mismatch: got {}, expected {}",
                physical_out.len(),
                self.physical_to_outer.len()
            ));
        }
        // Clear the caller's buffer before accumulating into it.
        out.fill(0.0);
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            if let Some(outer_idx) = *outer_idx {
                out[outer_idx] += physical_out[physical_idx];
            }
        }
        Ok(())
    }

    /// Applies the operator to every column of `factor` and returns the outer-coordinate
    /// result with `outer_dim` rows and `factor.ncols()` columns.
    ///
    /// # Errors
    /// Fails when `factor.nrows() != outer_dim`, when `base.mul_mat` fails, or when the base
    /// result does not have one row per physical index and one column per factor column.
    fn mul_mat(&self, factor: ndarray::ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        if factor.nrows() != self.outer_dim {
            return Err(format!(
                "labeled outer Hessian factor row mismatch: got {}, expected {}",
                factor.nrows(),
                self.outer_dim
            ));
        }
        // Scatter rows: unmapped physical rows stay zero.
        let mut physical_factor =
            Array2::<f64>::zeros((self.physical_to_outer.len(), factor.ncols()));
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            if let Some(outer_idx) = *outer_idx {
                physical_factor
                    .row_mut(physical_idx)
                    .assign(&factor.row(outer_idx));
            }
        }
        let physical_out = self.base.mul_mat(physical_factor.view())?;
        if physical_out.nrows() != self.physical_to_outer.len() {
            return Err(format!(
                "labeled outer Hessian physical output row mismatch: got {}, expected {}",
                physical_out.nrows(),
                self.physical_to_outer.len()
            ));
        }
        // `scaled_add` broadcasts a one-column row across the output and panics on any other
        // width mismatch, so reject a wrong column count here as an ordinary error.
        if physical_out.ncols() != factor.ncols() {
            return Err(format!(
                "labeled outer Hessian physical output column mismatch: got {}, expected {}",
                physical_out.ncols(),
                factor.ncols()
            ));
        }
        // Gather rows: physical rows sharing an outer index are summed together.
        let mut out = Array2::<f64>::zeros((self.outer_dim, factor.ncols()));
        for (physical_idx, outer_idx) in self.physical_to_outer.iter().enumerate() {
            if let Some(outer_idx) = *outer_idx {
                let physical_row = physical_out.row(physical_idx);
                out.row_mut(outer_idx).scaled_add(1.0, &physical_row);
            }
        }
        Ok(out)
    }

    /// Forwards the answer of the wrapped operator.
    fn is_cheap_to_materialize(&self) -> bool {
        self.base.is_cheap_to_materialize()
    }

    /// Forwards the materialization capability of the wrapped operator.
    fn materialization_capability(
        &self,
    ) -> crate::solver::rho_optimizer::OuterHessianMaterialization {
        self.base.materialization_capability()
    }
}
/// Asks `family` for its batched outer Hessian terms and returns them as an operator.
///
/// Returns `Ok(None)` when `eval_mode` is not `EvalMode::ValueGradientHessian`, when the family
/// reports no batched terms, or when the reported Hessian is `Unavailable`. An `Operator`
/// result is passed through unchanged; an `Analytic` matrix is wrapped in an
/// `OwnedDenseOuterHessianOperator`.
///
/// # Errors
/// Propagates any error returned by `family.batched_outer_hessian_terms`.
pub(crate) fn custom_family_batched_outer_hessian_operator<F: CustomFamily>(
    family: &F,
    states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    rho: &Array1<f64>,
    workspace: Option<Arc<dyn ExactNewtonJointHessianWorkspace>>,
    eval_mode: EvalMode,
) -> Result<Option<Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>>, String> {
    use crate::solver::rho_optimizer::HessianResult;

    // The family is only queried when a Hessian was actually requested.
    if eval_mode != EvalMode::ValueGradientHessian {
        return Ok(None);
    }
    let batched =
        match family.batched_outer_hessian_terms(states, specs, derivative_blocks, rho, workspace)? {
            Some(batched) => batched,
            None => return Ok(None),
        };
    let operator: Option<Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>> =
        match batched.outer_hessian {
            HessianResult::Operator(shared) => Some(shared),
            HessianResult::Analytic(matrix) => {
                Some(Arc::new(OwnedDenseOuterHessianOperator { matrix }))
            }
            HessianResult::Unavailable => None,
        };
    Ok(operator)
}
/// Packs an EFS evaluation, a constrained warm start, and the inner convergence flag into a
/// `CustomFamilyJointHyperEfsResult`, wrapping the warm start in `CustomFamilyWarmStart`.
pub(crate) fn outer_efs_result_to_joint_hyper_efs_result(
    efs_eval: crate::solver::rho_optimizer::EfsEval,
    warm_start: ConstrainedWarmStart,
    inner_converged: bool,
) -> CustomFamilyJointHyperEfsResult {
    let warm_start = CustomFamilyWarmStart { inner: warm_start };
    CustomFamilyJointHyperEfsResult {
        efs_eval,
        warm_start,
        inner_converged,
    }
}
/// Runs `f` with the design matrix and offset that apply to one parameter block.
///
/// When `family.block_geometry_is_dynamic()` is false, `f` receives `spec.solver_design()` and
/// `spec.solver_offset()`. Otherwise the geometry comes from `family.block_geometry`, and its
/// shape is validated before `f` is called.
///
/// # Errors
/// Propagates errors from `family.block_geometry` and from `f`, and returns a
/// dimension-mismatch message (tagged with `block_idx`) when the dynamic design or offset has
/// an unexpected shape.
pub(crate) fn with_block_geometry<F: CustomFamily + ?Sized, T>(
    family: &F,
    block_states: &[ParameterBlockState],
    spec: &ParameterBlockSpec,
    block_idx: usize,
    f: impl FnOnce(&DesignMatrix, &Array1<f64>) -> Result<T, String>,
) -> Result<T, String> {
    // Static geometry needs no validation: hand over the spec's own design and offset.
    if !family.block_geometry_is_dynamic() {
        return f(spec.solver_design(), spec.solver_offset());
    }

    let (dyn_design, dyn_offset) = family.block_geometry(block_states, spec)?;
    let mismatch =
        |reason: String| -> String { CustomFamilyError::DimensionMismatch { reason }.into() };

    let expected_rows = spec.solver_design().nrows();
    if dyn_design.nrows() != expected_rows {
        return Err(mismatch(format!(
            "block {block_idx} dynamic design row mismatch: got {}, expected {}",
            dyn_design.nrows(),
            expected_rows
        )));
    }

    // NOTE(review): columns are compared against `spec.design`, while rows and offset use
    // `spec.solver_design()` — confirm this asymmetry is intentional.
    let expected_cols = spec.design.ncols();
    if dyn_design.ncols() != expected_cols {
        return Err(mismatch(format!(
            "block {block_idx} dynamic design col mismatch: got {}, expected {}",
            dyn_design.ncols(),
            expected_cols
        )));
    }

    if dyn_offset.len() != expected_rows {
        return Err(mismatch(format!(
            "block {block_idx} dynamic offset length mismatch: got {}, expected {}",
            dyn_offset.len(),
            expected_rows
        )));
    }

    f(&dyn_design, &dyn_offset)
}