use super::*;
/// Describes whether, and by which route, an outer Hessian operator can be
/// turned into a dense matrix.
///
/// Within this file only the split between [`Self::Unavailable`] and the
/// remaining variants is interpreted (see [`Self::is_available`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OuterHessianMaterialization {
    /// No materialization route; the only variant for which
    /// [`Self::is_available`] returns `false`.
    Unavailable,
    /// Reported by the default `OuterHessianOperator::materialization_capability`
    /// when the operator declares itself cheap to materialize.
    /// NOTE(review): presumably one Hessian-vector product per column — confirm.
    RepeatedHvp,
    /// NOTE(review): presumably a batched Hessian-vector-product route; it is
    /// not produced anywhere in this part of the file — confirm at use sites.
    BatchedHvp,
    /// NOTE(review): presumably the dense matrix is directly available;
    /// confirm at use sites.
    Explicit,
}

impl OuterHessianMaterialization {
    /// Returns `true` for every variant except [`Self::Unavailable`].
    pub(crate) fn is_available(self) -> bool {
        match self {
            Self::Unavailable => false,
            Self::RepeatedHvp | Self::BatchedHvp | Self::Explicit => true,
        }
    }
}
/// Error categories raised by the outer Hessian operator code in this file.
///
/// Every variant carries a fully formatted, human-readable `reason`; the
/// `Display` output is exactly that string, with no prefix added.
#[derive(Clone, Debug)]
pub enum OuterStrategyError {
    /// A vector or matrix handed to, or produced by, an operator has the wrong size.
    OperatorShape { reason: String },
    /// A materialized Hessian contains non-finite entries.
    NonFiniteHessian { reason: String },
    /// The rho block does not fit the quantity it should be added to.
    RhoBlockShape { reason: String },
}

impl std::fmt::Display for OuterStrategyError {
    /// Writes the variant's `reason` verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::OperatorShape { reason } => reason.as_str(),
            Self::NonFiniteHessian { reason } => reason.as_str(),
            Self::RhoBlockShape { reason } => reason.as_str(),
        };
        f.write_str(text)
    }
}

impl std::error::Error for OuterStrategyError {}

impl From<OuterStrategyError> for String {
    /// Converts the error into its display text.
    ///
    /// The display text of every variant is exactly its `reason`, so the owned
    /// string is moved out rather than formatted again.
    fn from(err: OuterStrategyError) -> String {
        match err {
            OuterStrategyError::OperatorShape { reason }
            | OuterStrategyError::NonFiniteHessian { reason }
            | OuterStrategyError::RhoBlockShape { reason } => reason,
        }
    }
}
/// Matrix-free interface to an outer Hessian.
///
/// Implementors supply [`dim`](Self::dim) and [`matvec`](Self::matvec); every
/// other method has a default built on top of those two. All fallible methods
/// report failures as plain `String` messages.
pub trait OuterHessianOperator: Send + Sync {
    /// Size of the operator: the row count required of `mul_mat` factors and
    /// the length every `matvec` result is checked against.
    fn dim(&self) -> usize;

    /// Applies the operator to `v` and returns the product.
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String>;

    /// Applies the operator to `v` and stores the product in `out`.
    ///
    /// The default computes [`matvec`](Self::matvec) and copies the result.
    ///
    /// # Errors
    /// Propagates `matvec` failures, and fails when the product length differs
    /// from `out.len()`; `out` is left untouched in both cases.
    fn apply_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<(), String> {
        let product = self.matvec(v)?;
        let (got, expected) = (product.len(), out.len());
        if got != expected {
            return Err(format!(
                "outer Hessian operator matvec produced length {} but expected {}",
                got, expected
            ));
        }
        out.assign(&product);
        Ok(())
    }

    /// Whether the operator considers dense materialization cheap.
    /// Defaults to `false`.
    fn is_cheap_to_materialize(&self) -> bool {
        false
    }

    /// Materialization route advertised by the operator.
    ///
    /// The default maps [`is_cheap_to_materialize`](Self::is_cheap_to_materialize)
    /// to `RepeatedHvp` (cheap) or `Unavailable` (not cheap).
    fn materialization_capability(&self) -> OuterHessianMaterialization {
        match self.is_cheap_to_materialize() {
            true => OuterHessianMaterialization::RepeatedHvp,
            false => OuterHessianMaterialization::Unavailable,
        }
    }

    /// Applies the operator to every column of `factor`.
    ///
    /// The default runs one [`matvec`](Self::matvec) per column through rayon's
    /// parallel iterator and assembles the products into a `dim x ncols` matrix.
    ///
    /// # Errors
    /// Fails when `factor` does not have `dim()` rows, when any `matvec` fails,
    /// or when any product does not have length `dim()`.
    fn mul_mat(&self, factor: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let n = self.dim();
        let (factor_rows, n_cols) = (factor.nrows(), factor.ncols());
        if factor_rows != n {
            return Err(String::from(OuterStrategyError::OperatorShape {
                reason: format!(
                    "outer Hessian operator factor row count mismatch: got {}, expected {}",
                    factor_rows, n
                ),
            }));
        }

        let products = (0..n_cols)
            .into_par_iter()
            .map(|j| -> Result<Array1<f64>, String> {
                // `matvec` takes an owned array reference, so each column is copied out.
                let product = self.matvec(&factor.column(j).to_owned())?;
                if product.len() == n {
                    Ok(product)
                } else {
                    Err(String::from(OuterStrategyError::OperatorShape {
                        reason: format!(
                            "outer Hessian operator matvec length mismatch: got {}, expected {}",
                            product.len(),
                            n
                        ),
                    }))
                }
            })
            .collect::<Result<Vec<_>, String>>()?;

        let mut out = Array2::<f64>::zeros((n, n_cols));
        for (j, product) in products.iter().enumerate() {
            out.column_mut(j).assign(product);
        }
        Ok(out)
    }

    /// Builds the dense `dim x dim` matrix of the operator.
    ///
    /// The operator is applied to the identity through [`mul_mat`](Self::mul_mat);
    /// the result is then symmetrized by averaging each pair of mirrored
    /// off-diagonal entries.
    ///
    /// # Errors
    /// Propagates `mul_mat` failures, and fails when `mul_mat` returns a matrix
    /// that is not `dim x dim` or when the symmetrized matrix has a non-finite entry.
    fn materialize_dense(&self) -> Result<Array2<f64>, String> {
        let n = self.dim();
        let eye = Array2::<f64>::eye(n);
        let mut dense = self.mul_mat(eye.view())?;

        let (rows, cols) = (dense.nrows(), dense.ncols());
        if (rows, cols) != (n, n) {
            return Err(String::from(OuterStrategyError::OperatorShape {
                reason: format!(
                    "outer Hessian operator mul_mat returned {}x{}, expected {}x{}",
                    rows, cols, n, n
                ),
            }));
        }

        // Visit each unordered pair (upper, lower) with upper < lower exactly once.
        for lower in 1..n {
            for upper in 0..lower {
                let mean = 0.5 * (dense[[upper, lower]] + dense[[lower, upper]]);
                dense[[upper, lower]] = mean;
                dense[[lower, upper]] = mean;
            }
        }

        if dense.iter().any(|entry| !entry.is_finite()) {
            return Err(String::from(OuterStrategyError::NonFiniteHessian {
                reason: "outer Hessian dense materialization produced non-finite entries"
                    .to_string(),
            }));
        }
        Ok(dense)
    }
}
/// Outer Hessian operator formed by adding a dense block to a wrapped operator.
///
/// The `OuterHessianOperator` impl first evaluates `base` and then adds the
/// product with `rho_block` to the leading `rho_block.nrows()` entries (rows)
/// of the result.
pub(crate) struct RhoBlockAdditiveOuterHessian {
/// Wrapped operator whose product is computed first.
pub(crate) base: Arc<dyn OuterHessianOperator>,
/// Dense block applied to, and accumulated into, the leading
/// `rho_block.nrows()` coordinates.
pub(crate) rho_block: Array2<f64>,
/// Value reported by `dim()` and the input length `matvec`/`apply_into` require.
/// NOTE(review): presumably equal to `base.dim()`; nothing here enforces it — confirm at construction sites.
pub(crate) dim: usize,
}
impl OuterHessianOperator for RhoBlockAdditiveOuterHessian {
    /// Returns the stored operator dimension.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Computes `base * v` and adds `rho_block * v[..k]` to its first `k`
    /// entries, where `k = rho_block.nrows()`.
    ///
    /// # Errors
    /// Fails when `v.len() != dim`, when `base.matvec` fails, or when a
    /// non-empty `rho_block` is not square or does not fit into the input or
    /// the base product (these cases used to panic inside ndarray).
    fn matvec(&self, v: &Array1<f64>) -> Result<Array1<f64>, String> {
        self.check_input_len(v.len())?;
        // Reject a malformed rho block before paying for the base product.
        let k = self.checked_rho_order(v.len(), "matvec input")?;
        let mut out = self.base.matvec(v)?;
        if k > 0 {
            // The base operator is not guaranteed to return `dim` entries.
            self.checked_rho_order(out.len(), "matvec output")?;
            let rho_out = self.rho_block.dot(&v.slice(ndarray::s![..k]));
            out.slice_mut(ndarray::s![..k]).scaled_add(1.0, &rho_out);
        }
        Ok(out)
    }

    /// In-place variant of [`matvec`](OuterHessianOperator::matvec): lets
    /// `base` fill `out`, then accumulates the rho-block contribution into the
    /// first `k` entries.
    ///
    /// # Errors
    /// Fails when `v` or `out` does not have length `dim`, when a non-empty
    /// `rho_block` is not square or larger than `dim`, or when
    /// `base.apply_into` fails. Shape problems are detected before `out` is
    /// modified.
    fn apply_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<(), String> {
        self.check_input_len(v.len())?;
        if out.len() != self.dim {
            return Err(OuterStrategyError::OperatorShape {
                reason: format!(
                    "outer Hessian apply_into output length mismatch: got {}, expected {}",
                    out.len(),
                    self.dim
                ),
            }
            .into());
        }
        // Both `v` and `out` have exactly `dim` entries at this point.
        let k = self.checked_rho_order(self.dim, "apply_into output")?;
        self.base.apply_into(v, out)?;
        if k > 0 {
            let v_top = v.slice(ndarray::s![..k]);
            // `zip` stops after the `k` rows of the rho block.
            for (slot, row) in out.iter_mut().zip(self.rho_block.outer_iter()) {
                *slot += row.dot(&v_top);
            }
        }
        Ok(())
    }

    /// Computes `base.mul_mat(factor)` and adds `rho_block * factor[..k, ..]`
    /// to its first `k` rows.
    ///
    /// # Errors
    /// Fails when `base.mul_mat` fails, or when a non-empty `rho_block` is not
    /// square, has more rows than the base product or the factor, or when the
    /// base product and the factor disagree on the column count.
    fn mul_mat(&self, factor: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
        let mut out = self.base.mul_mat(factor)?;
        let k = self.checked_rho_order(out.nrows(), "mul_mat output")?;
        if k > 0 {
            self.checked_rho_order(factor.nrows(), "mul_mat factor")?;
            // Without this check `scaled_add` would panic or silently broadcast.
            if out.ncols() != factor.ncols() {
                return Err(OuterStrategyError::RhoBlockShape {
                    reason: format!(
                        "rho-block Hessian update shape mismatch: mul_mat output has {} columns, factor has {}",
                        out.ncols(),
                        factor.ncols()
                    ),
                }
                .into());
            }
            let factor_top = factor.slice(ndarray::s![..k, ..]);
            let rho_contrib = self.rho_block.dot(&factor_top);
            out.slice_mut(ndarray::s![..k, ..])
                .scaled_add(1.0, &rho_contrib);
        }
        Ok(out)
    }

    /// Delegates to the wrapped operator.
    fn is_cheap_to_materialize(&self) -> bool {
        self.base.is_cheap_to_materialize()
    }

    /// Delegates to the wrapped operator.
    fn materialization_capability(&self) -> OuterHessianMaterialization {
        self.base.materialization_capability()
    }
}

impl RhoBlockAdditiveOuterHessian {
    /// Checks that an input vector of length `len` matches `self.dim`.
    fn check_input_len(&self, len: usize) -> Result<(), String> {
        if len == self.dim {
            return Ok(());
        }
        Err(OuterStrategyError::OperatorShape {
            reason: format!(
                "outer Hessian operator input length mismatch: got {}, expected {}",
                len, self.dim
            ),
        }
        .into())
    }

    /// Validates `rho_block` against a quantity (`target`, used only in the
    /// error text) that offers `rows_available` leading rows, and returns the
    /// block order `k = rho_block.nrows()`.
    ///
    /// A block with zero rows is always accepted and yields `0`, because no
    /// update is applied in that case. Otherwise the block must be square and
    /// satisfy `k <= rows_available`.
    fn checked_rho_order(&self, rows_available: usize, target: &str) -> Result<usize, String> {
        let k = self.rho_block.nrows();
        if k == 0 {
            return Ok(0);
        }
        let cols = self.rho_block.ncols();
        if cols != k || k > rows_available {
            return Err(OuterStrategyError::RhoBlockShape {
                reason: format!(
                    "rho-block Hessian update shape mismatch: rho_block is {}x{}, {} has {} rows",
                    k, cols, target, rows_available
                ),
            }
            .into());
        }
        Ok(k)
    }
}
/// Dimension threshold related to materializing an outer Hessian from Hessian-vector products.
/// NOTE(review): not referenced in this part of the file; presumably operators whose `dim()`
/// exceeds this value are not densified through repeated HVPs — confirm at the use sites.
pub(crate) const OUTER_HVP_MATERIALIZE_MAX_DIM: usize = 64;
/// Two-state flag describing the availability of a derivative.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Derivative {
    /// NOTE(review): presumably the derivative is supplied in analytic
    /// (closed) form — confirm at use sites.
    Analytic,
    /// NOTE(review): presumably no derivative is supplied — confirm at use sites.
    Unavailable,
}
/// Form in which a Hessian is declared to be obtainable.
///
/// The predicates on this type are the only interpretation given in this file;
/// the variant notes below that go beyond them are marked as assumptions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DeclaredHessianForm {
    /// Dense form only (`is_dense_only` is `true`).
    Dense,
    /// Operator form only (`is_operator_only` is `true`).
    Operator {
        /// Materialization route advertised for the operator.
        materialization: OuterHessianMaterialization,
        /// Optional cost estimate for materializing the operator.
        /// NOTE(review): units and scale are not visible here — confirm.
        estimated_materialization_cost: Option<f64>,
    },
    /// NOTE(review): presumably both dense and operator forms can be produced;
    /// it is analytic but neither dense-only nor operator-only — confirm.
    Either,
    /// No Hessian is declared; the only variant for which `is_analytic` is `false`.
    Unavailable,
}

impl DeclaredHessianForm {
    /// Returns `true` for every variant except [`Self::Unavailable`].
    pub const fn is_analytic(self) -> bool {
        match self {
            Self::Unavailable => false,
            Self::Dense | Self::Operator { .. } | Self::Either => true,
        }
    }

    /// Returns `true` only for [`Self::Operator`].
    pub const fn is_operator_only(self) -> bool {
        match self {
            Self::Operator { .. } => true,
            Self::Dense | Self::Either | Self::Unavailable => false,
        }
    }

    /// Returns `true` only for [`Self::Dense`].
    pub const fn is_dense_only(self) -> bool {
        match self {
            Self::Dense => true,
            Self::Operator { .. } | Self::Either | Self::Unavailable => false,
        }
    }
}