use std::sync::Arc;
use ndarray::{Array1, Array2, ArrayView2};
/// Describes whether, and by which route, an outer Hessian operator can be
/// turned into a dense matrix.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names and from `OuterHessianOperator::materialization_capability`; confirm
/// against the call sites that consume this value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OuterHessianMaterialization {
    /// No dense materialization is offered.
    Unavailable,
    /// Presumably: dense form assembled from repeated Hessian-vector products.
    RepeatedHvp,
    /// Presumably: dense form assembled from a batched Hessian-vector product.
    BatchedHvp,
    /// Presumably: the dense form is available directly.
    Explicit,
}

impl OuterHessianMaterialization {
    /// Returns `true` for every variant except [`Self::Unavailable`].
    pub fn is_available(self) -> bool {
        match self {
            Self::Unavailable => false,
            Self::RepeatedHvp | Self::BatchedHvp | Self::Explicit => true,
        }
    }
}
/// Error categories used by the outer-strategy code; every variant carries a
/// human-readable `reason`.
#[derive(Clone, Debug)]
pub enum OuterStrategyError {
    /// A shape mismatch involving an outer Hessian operator, e.g. a factor
    /// row count or a product length that differs from the operator's `dim()`.
    OperatorShape { reason: String },
    /// A materialized Hessian contained non-finite entries.
    NonFiniteHessian { reason: String },
    /// NOTE(review): not constructed in the visible code; presumably a shape
    /// problem with a rho block — confirm against its users.
    RhoBlockShape { reason: String },
}
// Invokes `impl_reason_error_boilerplate!` (defined outside this view) for the
// three `reason`-carrying variants of `OuterStrategyError`.
// NOTE(review): presumably this generates the `Display`/`Error` impls and the
// `OuterStrategyError -> String` conversion that the `.into()` calls in
// `OuterHessianOperator` rely on — confirm against the macro definition.
impl_reason_error_boilerplate! {
OuterStrategyError {
OperatorShape,
NonFiniteHessian,
RhoBlockShape,
}
}
/// Matrix-free view of an outer Hessian: a square linear operator of size
/// `dim()` that is applied through Hessian-vector products.
///
/// Implementors supply [`dim`](Self::dim) and [`matvec`](Self::matvec); every
/// other method has a default implementation built on those two.
pub trait OuterHessianOperator: Send + Sync {
    /// Number of rows (and columns) of the operator.
    fn dim(&self) -> usize;

    /// Applies the operator to `v` and returns the product.
    ///
    /// # Errors
    /// Implementation-defined; reported as a `String`.
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String>;

    /// Applies the operator to `v` and writes the product into `out`.
    ///
    /// The default goes through [`matvec`](Self::matvec) and copies the result.
    ///
    /// # Errors
    /// Propagates any `matvec` error, and fails when the product length differs
    /// from `out.len()` (in which case `out` is left untouched).
    fn apply_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<(), String> {
        let product = self.matvec(v)?;
        if product.len() == out.len() {
            out.assign(&product);
            Ok(())
        } else {
            Err(format!(
                "outer Hessian operator matvec produced length {} but expected {}",
                product.len(),
                out.len()
            ))
        }
    }

    /// Whether building the dense matrix is considered cheap. Defaults to `false`.
    fn is_cheap_to_materialize(&self) -> bool {
        false
    }

    /// Reports how this operator can be materialized.
    ///
    /// The default maps [`is_cheap_to_materialize`](Self::is_cheap_to_materialize)
    /// to `RepeatedHvp` (cheap) or `Unavailable` (not cheap).
    fn materialization_capability(&self) -> OuterHessianMaterialization {
        if !self.is_cheap_to_materialize() {
            return OuterHessianMaterialization::Unavailable;
        }
        OuterHessianMaterialization::RepeatedHvp
    }

    /// Applies the operator to every column of `factor`, returning a
    /// `dim() x factor.ncols()` matrix.
    ///
    /// The default runs one [`matvec`](Self::matvec) per column, distributing
    /// the columns over rayon's parallel iterator.
    ///
    /// # Errors
    /// Fails when `factor` does not have `dim()` rows, when any `matvec` call
    /// fails, or when a `matvec` result does not have length `dim()`.
    fn mul_mat(&self, factor: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let n = self.dim();
        let factor_rows = factor.nrows();
        if factor_rows != n {
            return Err(OuterStrategyError::OperatorShape {
                reason: format!(
                    "outer Hessian operator factor row count mismatch: got {}, expected {}",
                    factor_rows, n
                ),
            }
            .into());
        }

        let n_cols = factor.ncols();
        let products = (0..n_cols)
            .into_par_iter()
            .map(|j| -> Result<Array1<f64>, String> {
                // `matvec` takes an owned 1-D array by reference, so the column
                // view has to be copied out first.
                let input = factor.column(j).to_owned();
                let product = self.matvec(&input)?;
                if product.len() == n {
                    Ok(product)
                } else {
                    Err(OuterStrategyError::OperatorShape {
                        reason: format!(
                            "outer Hessian operator matvec length mismatch: got {}, expected {}",
                            product.len(),
                            n
                        ),
                    }
                    .into())
                }
            })
            .collect::<Result<Vec<_>, String>>()?;

        // Stitch the per-column products back together in column order.
        let mut result = Array2::<f64>::zeros((n, n_cols));
        for (j, product) in products.iter().enumerate() {
            result.column_mut(j).assign(product);
        }
        Ok(result)
    }

    /// Builds the dense `dim() x dim()` matrix of the operator.
    ///
    /// The default applies [`mul_mat`](Self::mul_mat) to the identity, replaces
    /// each off-diagonal pair by its average so the result is exactly
    /// symmetric, and then rejects matrices with non-finite entries.
    ///
    /// # Errors
    /// Propagates `mul_mat` errors, and fails when `mul_mat` returns a matrix
    /// that is not `dim() x dim()` or when the symmetrized matrix contains a
    /// non-finite value.
    fn materialize_dense(&self) -> Result<Array2<f64>, String> {
        let n = self.dim();
        let eye = Array2::<f64>::eye(n);
        let mut dense = self.mul_mat(eye.view())?;

        // Guard against overriding `mul_mat` implementations returning the
        // wrong shape; the indexed loop below would otherwise panic.
        let (got_rows, got_cols) = (dense.nrows(), dense.ncols());
        if got_rows != n || got_cols != n {
            return Err(OuterStrategyError::OperatorShape {
                reason: format!(
                    "outer Hessian operator mul_mat returned {}x{}, expected {}x{}",
                    got_rows, got_cols, n, n
                ),
            }
            .into());
        }

        // Average each (i, j) / (j, i) pair above the diagonal.
        for i in 0..n {
            for j in (i + 1)..n {
                let mean = 0.5 * (dense[[i, j]] + dense[[j, i]]);
                dense[[i, j]] = mean;
                dense[[j, i]] = mean;
            }
        }

        if dense.iter().any(|entry| !entry.is_finite()) {
            return Err(OuterStrategyError::NonFiniteHessian {
                reason: "outer Hessian dense materialization produced non-finite entries"
                    .to_string(),
            }
            .into());
        }
        Ok(dense)
    }
}
/// Two-state flag describing a derivative.
///
/// NOTE(review): not used in the visible code; the variant descriptions follow
/// the variant names only — confirm against the users of this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Derivative {
    /// Presumably: the derivative is supplied analytically.
    Analytic,
    /// Presumably: no derivative is supplied.
    Unavailable,
}
/// The form in which a Hessian is declared to be provided.
///
/// NOTE(review): the variant descriptions are inferred from the names and from
/// the predicates on this type; confirm against the code that declares and
/// consumes these values.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DeclaredHessianForm {
    /// Presumably: provided as a dense matrix.
    Dense,
    /// Presumably: provided as a matrix-free operator.
    Operator {
        /// How (if at all) the operator can be turned into a dense matrix.
        materialization: OuterHessianMaterialization,
        /// Optional cost estimate for that materialization; units are not
        /// visible here — TODO confirm.
        estimated_materialization_cost: Option<f64>,
    },
    /// Presumably: either a dense matrix or an operator may be provided.
    Either,
    /// No Hessian is provided.
    Unavailable,
}
impl DeclaredHessianForm {
    /// `true` for every form except `Unavailable`.
    pub const fn is_analytic(self) -> bool {
        match self {
            DeclaredHessianForm::Unavailable => false,
            DeclaredHessianForm::Dense
            | DeclaredHessianForm::Operator { .. }
            | DeclaredHessianForm::Either => true,
        }
    }

    /// `true` only for the `Operator { .. }` form.
    pub const fn is_operator_only(self) -> bool {
        match self {
            DeclaredHessianForm::Operator { .. } => true,
            DeclaredHessianForm::Dense
            | DeclaredHessianForm::Either
            | DeclaredHessianForm::Unavailable => false,
        }
    }

    /// `true` only for the `Dense` form.
    pub const fn is_dense_only(self) -> bool {
        match self {
            DeclaredHessianForm::Dense => true,
            DeclaredHessianForm::Operator { .. }
            | DeclaredHessianForm::Either
            | DeclaredHessianForm::Unavailable => false,
        }
    }
}
/// Result of one outer-objective evaluation: value, gradient, Hessian (in
/// whatever form is available) and an optional inner-solution hint.
pub struct OuterEval {
    /// Objective value; the `infeasible` constructor sets it to `f64::INFINITY`.
    pub cost: f64,
    /// Gradient of the objective; the `infeasible` and `value_only`
    /// constructors fill it with zeros.
    pub gradient: Array1<f64>,
    /// Hessian as a dense matrix, an operator, or unavailable.
    pub hessian: HessianResult,
    /// Optional vector carried along with the evaluation.
    /// NOTE(review): presumably a warm start for the inner solve — confirm.
    pub inner_beta_hint: Option<Array1<f64>>,
}
impl OuterEval {
    /// Evaluation marking an infeasible point: infinite cost, a zero gradient
    /// of length `n_params`, no Hessian and no inner hint.
    pub fn infeasible(n_params: usize) -> Self {
        Self::value_only(f64::INFINITY, n_params, None)
    }

    /// Evaluation that carries only a cost (plus an optional inner hint): the
    /// gradient is a zero vector of length `n_params` and the Hessian is
    /// `HessianResult::Unavailable`.
    pub fn value_only(cost: f64, n_params: usize, inner_beta_hint: Option<Array1<f64>>) -> Self {
        let gradient = Array1::<f64>::zeros(n_params);
        Self {
            cost,
            gradient,
            hessian: HessianResult::Unavailable,
            inner_beta_hint,
        }
    }
}
/// Hessian attached to an outer evaluation.
pub enum HessianResult {
    /// Hessian held as a dense matrix.
    Analytic(Array2<f64>),
    /// Hessian exposed as a shared, matrix-free operator.
    Operator(Arc<dyn OuterHessianOperator>),
    /// No Hessian was produced.
    Unavailable,
}
impl Clone for OuterEval {
    /// Field-by-field clone.
    fn clone(&self) -> Self {
        // Destructure so that adding a field to `OuterEval` forces this impl
        // to be revisited.
        let Self {
            cost,
            gradient,
            hessian,
            inner_beta_hint,
        } = self;
        Self {
            cost: *cost,
            gradient: gradient.clone(),
            hessian: hessian.clone(),
            inner_beta_hint: inner_beta_hint.clone(),
        }
    }
}
impl std::fmt::Debug for OuterEval {
    /// Formats `cost`, `gradient` and `hessian`.
    ///
    /// `inner_beta_hint` is not printed; the output ends with `..` so that
    /// readers can tell a field has been left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OuterEval")
            .field("cost", &self.cost)
            .field("gradient", &self.gradient)
            .field("hessian", &self.hessian)
            // Marks the omitted `inner_beta_hint` instead of presenting the
            // listed fields as the complete struct.
            .finish_non_exhaustive()
    }
}
impl Clone for HessianResult {
    /// Clones the dense matrix for `Analytic`; for `Operator` only the `Arc`
    /// handle is cloned, so the clone refers to the same operator.
    fn clone(&self) -> Self {
        match self {
            Self::Unavailable => Self::Unavailable,
            Self::Analytic(dense) => Self::Analytic(dense.clone()),
            Self::Operator(shared) => Self::Operator(Arc::clone(shared)),
        }
    }
}
impl std::fmt::Debug for HessianResult {
    /// Prints a compact summary rather than the matrix contents: the shape
    /// for `Analytic`, the dimension for `Operator`, and the bare variant name
    /// for `Unavailable`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, summary) = match self {
            Self::Analytic(h) => ("Analytic", format!("{}x{}", h.nrows(), h.ncols())),
            Self::Operator(op) => ("Operator", format!("dim={}", op.dim())),
            Self::Unavailable => return f.write_str("Unavailable"),
        };
        f.debug_tuple(variant).field(&summary).finish()
    }
}
impl HessianResult {
    /// `true` when a Hessian is present in either dense or operator form.
    pub fn is_analytic(&self) -> bool {
        match self {
            HessianResult::Analytic(_) | HessianResult::Operator(_) => true,
            HessianResult::Unavailable => false,
        }
    }

    /// Dimension of the Hessian: the row count of the dense matrix or the
    /// operator's `dim()`; `None` when no Hessian is available.
    pub fn dim(&self) -> Option<usize> {
        let size = match self {
            HessianResult::Analytic(h) => h.nrows(),
            HessianResult::Operator(op) => op.dim(),
            HessianResult::Unavailable => return None,
        };
        Some(size)
    }

    /// Returns the Hessian as a dense matrix, if there is one.
    ///
    /// `Analytic` yields a copy of the stored matrix, `Operator` delegates to
    /// the operator's `materialize_dense`, and `Unavailable` yields `Ok(None)`.
    ///
    /// # Errors
    /// Propagates any error from the operator's `materialize_dense`.
    pub fn materialize_dense(&self) -> Result<Option<Array2<f64>>, String> {
        Ok(match self {
            HessianResult::Analytic(h) => Some(h.clone()),
            HessianResult::Operator(op) => Some(op.materialize_dense()?),
            HessianResult::Unavailable => None,
        })
    }
}
/// Result record of an "EFS" evaluation.
///
/// NOTE(review): the producers and consumers of this struct are outside the
/// visible code, so the field notes below are limited to what the types show;
/// confirm the intended semantics before relying on them.
#[derive(Clone, Debug)]
pub struct EfsEval {
    /// Scalar cost value.
    pub cost: f64,
    /// Step values; presumably one per outer parameter — TODO confirm.
    pub steps: Vec<f64>,
    /// Optional coefficient vector.
    pub beta: Option<Array1<f64>>,
    /// Optional gradient with respect to "psi" parameters.
    pub psi_gradient: Option<Array1<f64>>,
    /// Optional index list; presumably the positions that `psi_gradient`
    /// refers to — TODO confirm.
    pub psi_indices: Option<Vec<usize>>,
    /// Optional scale associated with the inner Hessian.
    pub inner_hessian_scale: Option<f64>,
    /// Optional log-determinant enclosure gap; set through
    /// `with_logdet_enclosure_gap`.
    pub logdet_enclosure_gap: Option<f64>,
}
impl EfsEval {
    /// Consumes `self` and returns it with `logdet_enclosure_gap` replaced by
    /// `gap`; all other fields are carried over unchanged.
    pub fn with_logdet_enclosure_gap(self, gap: Option<f64>) -> Self {
        Self {
            logdet_enclosure_gap: gap,
            ..self
        }
    }
}