use super::*;
/// Source of family-specific correction terms for derivatives of a Hessian.
///
/// Implementors must provide the dense first-order correction and report
/// whether they contribute corrections at all; every other method has a
/// default built on top of those two.
pub trait HessianDerivativeProvider: Send + Sync {
    /// Dense first-order correction for the direction `v_k`.
    ///
    /// `Ok(None)` signals that there is nothing to add for this direction.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String>;

    /// First-order correction wrapped as a [`DriftDerivResult`].
    ///
    /// The default wraps the dense matrix from
    /// [`Self::hessian_derivative_correction`] in `DriftDerivResult::Dense`.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let dense = self.hessian_derivative_correction(v_k)?;
        Ok(dense.map(DriftDerivResult::Dense))
    }

    /// First-order corrections for several directions, in input order.
    ///
    /// The default evaluates the directions one at a time and stops at the
    /// first error.
    fn hessian_derivative_corrections_result(
        &self,
        v_ks: &[Array1<f64>],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        let mut corrections = Vec::with_capacity(v_ks.len());
        for direction in v_ks {
            corrections.push(self.hessian_derivative_correction_result(direction)?);
        }
        Ok(corrections)
    }

    /// Whether the provider overrides the batched first-order entry point.
    /// Defaults to `false`.
    fn has_batched_hessian_derivative_corrections(&self) -> bool {
        false
    }

    /// Dense second-order correction for three direction vectors.
    ///
    /// The default has no second-order implementation: it returns `Ok(None)`
    /// when [`Self::has_corrections`] is `false` and an error otherwise.
    ///
    /// # Panics
    ///
    /// The default panics if any of the three inputs contains a NaN.
    fn hessian_second_derivative_correction(
        &self,
        arr: &Array1<f64>,
        arr2: &Array1<f64>,
        arr3: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        assert!(arr.iter().all(|v| !v.is_nan()));
        assert!(arr2.iter().all(|v| !v.is_nan()));
        assert!(arr3.iter().all(|v| !v.is_nan()));
        if !self.has_corrections() {
            return Ok(None);
        }
        Err(
            "HessianDerivativeProvider reports first-order corrections but does not implement second-order correction"
                .to_string(),
        )
    }

    /// Second-order correction wrapped as a [`DriftDerivResult`].
    ///
    /// The default wraps the dense matrix from
    /// [`Self::hessian_second_derivative_correction`] in
    /// `DriftDerivResult::Dense`.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let dense = self.hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        Ok(dense.map(DriftDerivResult::Dense))
    }

    /// Second-order corrections for several `(v_k, v_l, u_kl)` triples, in
    /// input order.
    ///
    /// The default evaluates the triples one at a time and stops at the first
    /// error.
    fn hessian_second_derivative_corrections_result(
        &self,
        triples: &[(Array1<f64>, Array1<f64>, Array1<f64>)],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        let mut corrections = Vec::with_capacity(triples.len());
        for (v_k, v_l, u_kl) in triples {
            corrections.push(self.hessian_second_derivative_correction_result(v_k, v_l, u_kl)?);
        }
        Ok(corrections)
    }

    /// Whether the provider overrides the batched second-order entry point.
    /// Defaults to `false`.
    fn has_batched_hessian_second_derivative_corrections(&self) -> bool {
        false
    }

    /// Whether this provider contributes any correction terms.
    fn has_corrections(&self) -> bool;

    /// Borrowed scalar-GLM ingredients, if the provider can expose them.
    /// Defaults to `None`.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }

    /// Owned kernel description of this provider.
    ///
    /// The default builds a `ScalarGlm` kernel from
    /// [`Self::scalar_glm_ingredients`] and yields `None` when no ingredients
    /// are available.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        let ingredients = self.scalar_glm_ingredients()?;
        Some(OuterHessianDerivativeKernel::from_scalar_glm(ingredients))
    }

    /// Family-specific outer-Hessian operator, if any. Defaults to `None`.
    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::rho_optimizer::OuterHessianOperator>> {
        None
    }
}
/// Borrowed inputs that `OuterHessianDerivativeKernel::from_scalar_glm` clones
/// into an owned `ScalarGlm` kernel.
pub struct ScalarGlmIngredients<'a> {
/// Per-observation array of the provider's `c` coefficients.
/// NOTE(review): presumably a higher-order derivative of the working
/// weights — confirm against the caller.
pub c_array: &'a Array1<f64>,
/// Optional per-observation array of the provider's `d` coefficients; `None`
/// when the provider has none.
pub d_array: Option<&'a Array1<f64>>,
/// Design matrix the arrays above are paired with.
pub x: &'a DesignMatrix,
}
/// Owned, cloneable description of how Hessian-derivative corrections are
/// obtained for a provider.
#[derive(Clone)]
pub enum OuterHessianDerivativeKernel {
/// Kernel reported by `GaussianDerivatives`, whose provider has no
/// corrections.
Gaussian,
/// Owned copies of the arrays and design matrix taken from a
/// `ScalarGlmIngredients` value.
ScalarGlm {
c_array: Array1<f64>,
d_array: Option<Array1<f64>>,
x: DesignMatrix,
},
/// Corrections supplied through shared, thread-safe closures.
Callback {
/// Closure taking a single direction vector.
/// NOTE(review): presumably the first-order correction — confirm.
first: Arc<dyn Fn(&Array1<f64>) -> Result<Option<DriftDerivResult>, String> + Send + Sync>,
/// Closure taking two vectors.
/// NOTE(review): presumably the second-order correction — confirm the
/// meaning of each argument against the call sites.
second: Arc<
dyn Fn(&Array1<f64>, &Array1<f64>) -> Result<Option<DriftDerivResult>, String>
+ Send
+ Sync,
>,
},
}
impl OuterHessianDerivativeKernel {
pub(crate) fn from_scalar_glm(ingredients: ScalarGlmIngredients<'_>) -> Self {
Self::ScalarGlm {
c_array: ingredients.c_array.clone(),
d_array: ingredients.d_array.cloned(),
x: ingredients.x.clone(),
}
}
}
/// Stateless provider that reports no Hessian-derivative corrections and the
/// `Gaussian` kernel.
pub struct GaussianDerivatives;
impl HessianDerivativeProvider for GaussianDerivatives {
    /// Always `Ok(None)`: this provider never contributes a correction.
    ///
    /// # Panics
    ///
    /// Panics if `arr` contains a NaN.
    fn hessian_derivative_correction(
        &self,
        arr: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        assert!(arr.iter().all(|v| !v.is_nan()));
        let no_correction: Option<Array2<f64>> = None;
        Ok(no_correction)
    }

    /// This provider has no correction terms.
    fn has_corrections(&self) -> bool {
        false
    }

    /// Reports the dedicated `Gaussian` kernel instead of the default
    /// scalar-GLM lookup.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        let kernel = OuterHessianDerivativeKernel::Gaussian;
        Some(kernel)
    }
}
/// Hessian-derivative provider for a model with a single design matrix,
/// with corrections of the form `Xᵀ diag(w) X` for direction-dependent
/// signed weights `w`.
pub struct SinglePredictorGlmDerivatives {
/// Per-observation array passed, together with `hessian_weights`, to
/// `crate::pirls::directionalworking_curvature_from_c_array`.
pub c_array: Array1<f64>,
/// Optional per-observation array; in the second-order correction it
/// multiplies the product of the two projected directions.
pub d_array: Option<Array1<f64>>,
/// Per-observation Hessian weights; the `d_array` term is only added for
/// observations whose weight is strictly positive.
pub hessian_weights: Array1<f64>,
/// Design matrix used to project directions and to form `Xᵀ diag(w) X`.
pub x_transformed: DesignMatrix,
}
impl HessianDerivativeProvider for SinglePredictorGlmDerivatives {
    /// Dense first-order correction `Xᵀ diag(w) X`, where `w` is the negated
    /// directional working curvature evaluated at `X v_k`.
    ///
    /// # Errors
    ///
    /// Returns the failure of the weighted cross-product, prefixed with
    /// `hessian_derivative_correction xtwx`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta_direction = self.x_transformed.matrixvectormultiply(v_k);
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(mut signed_weights) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_direction,
            );
        // Flip the sign of the curvature before forming the cross-product.
        signed_weights.mapv_inplace(|w| -w);
        let weights_view = SignedWeightsView::from_array(&signed_weights);
        let xtwx = self
            .x_transformed
            .xt_diag_x_signed_op(weights_view)
            .map_err(|e| format!("hessian_derivative_correction xtwx: {e}"))?;
        Ok(Some(xtwx))
    }

    /// Operator-backed first-order correction.
    ///
    /// Uses the same negated weights as the dense variant, but hands them to
    /// a `GlmCurvatureCorrectionOperator` instead of forming the matrix here.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let eta_direction = self.x_transformed.matrixvectormultiply(v_k);
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(mut neg_c_xv) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_direction,
            );
        neg_c_xv.mapv_inplace(|w| -w);
        let correction_op = GlmCurvatureCorrectionOperator {
            x_design: self.x_transformed.clone(),
            neg_c_xv,
            p: self.x_transformed.ncols(),
        };
        Ok(Some(DriftDerivResult::Operator(Arc::new(correction_op))))
    }

    /// Dense second-order correction `Xᵀ diag(w) X`.
    ///
    /// The weights start as the directional working curvature at `X u_kl`.
    /// When `d_array` is present, `d · (X v_k) · (X v_l)` is added for each
    /// observation with a strictly positive Hessian weight, skipping
    /// non-finite products.
    ///
    /// # Errors
    ///
    /// Returns the failure of the weighted cross-product, prefixed with
    /// `hessian_second_derivative_correction xtwx`.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta_k = self.x_transformed.matrixvectormultiply(v_k);
        let eta_l = self.x_transformed.matrixvectormultiply(v_l);
        let eta_kl = self.x_transformed.matrixvectormultiply(u_kl);
        let n_obs = self.x_transformed.nrows();
        let mut weights = Array1::zeros(n_obs);
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(curvature_kl) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_kl,
            );
        weights.assign(&curvature_kl);
        if let Some(d_array) = self.d_array.as_ref() {
            Zip::from(&mut weights)
                .and(d_array)
                .and(&eta_k)
                .and(&eta_l)
                .and(&self.hessian_weights)
                .par_for_each(|w, &d, &xvk, &xvl, &h| {
                    let delta = d * xvk * xvl;
                    // `h > 0.0` also rejects NaN weights.
                    if h > 0.0 && delta.is_finite() {
                        *w += delta;
                    }
                });
        }
        let weights_view = SignedWeightsView::from_array(&weights);
        let xtwx = self
            .x_transformed
            .xt_diag_x_signed_op(weights_view)
            .map_err(|e| format!("hessian_second_derivative_correction xtwx: {e}"))?;
        Ok(Some(xtwx))
    }

    /// This provider always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// Exposes `c_array`, `d_array` and `x_transformed` by reference.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        let ingredients = ScalarGlmIngredients {
            c_array: &self.c_array,
            d_array: self.d_array.as_ref(),
            x: &self.x_transformed,
        };
        Some(ingredients)
    }
}
/// Provider that combines the corrections of a `SinglePredictorGlmDerivatives`
/// with terms obtained from a shared `FirthDenseOperator`.
pub struct FirthAwareGlmDerivatives {
/// Provider whose corrections the Firth terms are subtracted from.
pub(crate) base: SinglePredictorGlmDerivatives,
/// Shared operator supplying `x_dense`, `direction_from_deta`,
/// `hphi_direction` and `hphisecond_direction_apply`.
pub(crate) firth_op: std::sync::Arc<super::super::FirthDenseOperator>,
}
impl HessianDerivativeProvider for FirthAwareGlmDerivatives {
    /// Dense first-order correction: the base provider's correction minus the
    /// Firth directional term evaluated at `-(X v_k)`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let glm_part = self.base.hessian_derivative_correction(v_k)?;
        let firth = &self.firth_op;
        let neg_eta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&firth.x_dense, v_k).mapv(|e| -e);
        let direction_k = firth.direction_from_deta(neg_eta_k);
        let firth_part = firth.hphi_direction(&direction_k);
        let combined = if let Some(mut dense) = glm_part {
            dense -= &firth_part;
            dense
        } else {
            -firth_part
        };
        Ok(Some(combined))
    }

    /// First-order correction as a [`DriftDerivResult`].
    ///
    /// A dense base result has the negated Firth term added in place. An
    /// operator-backed base result is wrapped, together with the negated
    /// Firth term, in a `CompositeHyperOperator`.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let glm_part = self.base.hessian_derivative_correction_result(v_k)?;
        let firth = &self.firth_op;
        let neg_eta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&firth.x_dense, v_k).mapv(|e| -e);
        let direction_k = firth.direction_from_deta(neg_eta_k);
        let neg_firth_part = -firth.hphi_direction(&direction_k);
        let combined = match glm_part {
            None => DriftDerivResult::Dense(neg_firth_part),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &neg_firth_part;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(neg_firth_part),
                    operators: vec![operator],
                    dim_hint: self.base.x_transformed.ncols(),
                }))
            }
        };
        Ok(Some(combined))
    }

    /// Dense second-order correction.
    ///
    /// Starts from the base provider's second-order correction (or zeros) and
    /// subtracts two Firth terms: the directional term at `X u_kl`, and the
    /// second directional term for `-(X v_k)` and `-(X v_l)` applied to the
    /// identity.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let glm_part = self
            .base
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        let firth = &self.firth_op;

        // Directional term along `u_kl` (not negated).
        let eta_kl: Array1<f64> = crate::faer_ndarray::fast_av(&firth.x_dense, u_kl);
        let direction_kl = firth.direction_from_deta(eta_kl);
        let firth_first = firth.hphi_direction(&direction_kl);

        // Second directional term along the negated `v_k` / `v_l` directions.
        let neg_eta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&firth.x_dense, v_k).mapv(|e| -e);
        let direction_k = firth.direction_from_deta(neg_eta_k);
        let neg_eta_l: Array1<f64> =
            crate::faer_ndarray::fast_av(&firth.x_dense, v_l).mapv(|e| -e);
        let direction_l = firth.direction_from_deta(neg_eta_l);
        let dim = v_k.len();
        let identity = Array2::<f64>::eye(dim);
        let firth_second = firth.hphisecond_direction_apply(&direction_k, &direction_l, &identity);

        let mut total = glm_part.unwrap_or_else(|| Array2::zeros((dim, dim)));
        total -= &firth_first;
        total -= &firth_second;
        Ok(Some(total))
    }

    /// This provider always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// Returns `None`: no scalar-GLM ingredients are exposed, so the default
    /// kernel lookup also yields `None`.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}
/// A Jeffreys term given by a shared `FirthDenseOperator`, an explicit value,
/// or both.
#[derive(Clone)]
pub struct ExactJeffreysTerm {
/// Operator whose `jeffreys_logdet()` supplies the value when no override is
/// set; `None` for value-only terms.
pub(crate) operator: Option<std::sync::Arc<super::super::FirthDenseOperator>>,
/// When `Some`, `value()` returns this instead of consulting the operator.
pub(crate) value_override: Option<f64>,
}
impl ExactJeffreysTerm {
    /// Term backed by `operator`, with no value override.
    pub(crate) fn new(operator: std::sync::Arc<super::super::FirthDenseOperator>) -> Self {
        let operator = Some(operator);
        Self {
            operator,
            value_override: None,
        }
    }

    /// Term that carries only the fixed value `phi` and no operator.
    pub(crate) fn value_only(phi: f64) -> Self {
        let value_override = Some(phi);
        Self {
            operator: None,
            value_override,
        }
    }

    /// Term that keeps `operator` but reports `projected_value` from
    /// [`Self::value`].
    pub(crate) fn with_projected_value(
        operator: std::sync::Arc<super::super::FirthDenseOperator>,
        projected_value: f64,
    ) -> Self {
        let mut term = Self::new(operator);
        term.value_override = Some(projected_value);
        term
    }

    /// Value of the term: the override if one is set, otherwise the
    /// operator's `jeffreys_logdet()`, otherwise `0.0`.
    #[inline]
    pub(crate) fn value(&self) -> f64 {
        if let Some(fixed) = self.value_override {
            return fixed;
        }
        match &self.operator {
            Some(operator) => operator.jeffreys_logdet(),
            None => 0.0,
        }
    }

    /// A new shared handle to the operator, if one is present.
    #[inline]
    pub(crate) fn operator_arc(&self) -> Option<std::sync::Arc<super::super::FirthDenseOperator>> {
        match &self.operator {
            Some(shared) => Some(std::sync::Arc::clone(shared)),
            None => None,
        }
    }
}
/// Log-barrier configuration for simple bounds on individual coefficients.
///
/// Entry `i` of the three vectors describes one bound, with slack
/// `bound_signs[i] * beta[constrained_indices[i]] - lower_bounds[i]`.
#[derive(Clone, Debug)]
pub struct BarrierConfig {
/// Barrier weight: the cost is `-tau * Σ ln(slack)`.
pub tau: f64,
/// Coefficient index each bound applies to. `from_constraints` emits one
/// entry per qualifying constraint row, so an index can appear more than
/// once.
pub constrained_indices: Vec<usize>,
/// Right-hand side of each bound, taken from the constraint vector `b`.
pub lower_bounds: Vec<f64>,
/// Sign of each bound: `1.0` or `-1.0` when built by `from_constraints`.
pub bound_signs: Vec<f64>,
}
impl BarrierConfig {
    /// Extracts the simple-bound rows of `constraints` into a barrier
    /// configuration.
    ///
    /// A row is a simple bound when, ignoring entries below `1e-14` in
    /// magnitude, it has exactly one entry and that entry is within `1e-14`
    /// of `+1` or `-1`. Every other row is skipped. Returns `None` when no
    /// constraints are given or no row qualifies.
    pub fn from_constraints(
        constraints: Option<&crate::pirls::LinearInequalityConstraints>,
    ) -> Option<Self> {
        const SIMPLE_BOUND_ENTRY_TOL: f64 = 1e-14;
        const DEFAULT_BARRIER_TAU: f64 = 1e-6;
        let constraints = constraints?;
        let mut constrained_indices = Vec::new();
        let mut lower_bounds = Vec::new();
        let mut bound_signs = Vec::new();
        for i in 0..constraints.a.nrows() {
            let row = constraints.a.row(i);
            // Column and sign of the single unit entry seen so far in this row.
            let mut unit_entry: Option<(usize, f64)> = None;
            let mut is_simple = true;
            for (j, &val) in row.iter().enumerate() {
                if val.abs() < SIMPLE_BOUND_ENTRY_TOL {
                    continue;
                }
                let is_unit = (val - 1.0).abs() < SIMPLE_BOUND_ENTRY_TOL
                    || (val + 1.0).abs() < SIMPLE_BOUND_ENTRY_TOL;
                if !is_unit || unit_entry.is_some() {
                    // A non-unit coefficient or a second non-zero entry.
                    is_simple = false;
                    break;
                }
                let sign = if val > 0.0 { 1.0 } else { -1.0 };
                unit_entry = Some((j, sign));
            }
            if let (true, Some((col, sign))) = (is_simple, unit_entry) {
                constrained_indices.push(col);
                lower_bounds.push(constraints.b[i]);
                bound_signs.push(sign);
            }
        }
        if constrained_indices.is_empty() {
            None
        } else {
            Some(BarrierConfig {
                tau: DEFAULT_BARRIER_TAU,
                constrained_indices,
                lower_bounds,
                bound_signs,
            })
        }
    }

    /// Slack of every bound at `beta`, or `None` as soon as one slack is
    /// `<= 0.0`.
    ///
    /// A NaN slack does not compare `<= 0.0`, so it is returned as if
    /// feasible.
    pub fn slacks(&self, beta: &Array1<f64>) -> Option<Vec<f64>> {
        self.constrained_indices
            .iter()
            .enumerate()
            .map(|(ci, &idx)| {
                let delta = self.bound_signs[ci] * beta[idx] - self.lower_bounds[ci];
                if delta <= 0.0 { None } else { Some(delta) }
            })
            .collect()
    }

    /// Adds `tau / slack²` to the diagonal entry of `h` for every bound.
    ///
    /// # Errors
    ///
    /// Fails when `beta` is infeasible, i.e. [`Self::slacks`] returns `None`.
    pub fn add_barrier_hessian_diagonal(
        &self,
        h: &mut Array2<f64>,
        beta: &Array1<f64>,
    ) -> Result<(), String> {
        let Some(slacks) = self.slacks(beta) else {
            return Err("Barrier: infeasible point (slack ≤ 0)".to_string());
        };
        for (&idx, slack) in self.constrained_indices.iter().zip(slacks.iter().copied()) {
            h[[idx, idx]] += self.tau / (slack * slack);
        }
        Ok(())
    }

    /// Barrier cost `-tau * Σ ln(slack)`, or `+∞` when any slack is `<= 0.0`.
    pub fn barrier_cost(&self, beta: &Array1<f64>) -> f64 {
        let log_sum = self
            .constrained_indices
            .iter()
            .enumerate()
            .try_fold(0.0_f64, |acc, (ci, &idx)| {
                let delta = self.bound_signs[ci] * beta[idx] - self.lower_bounds[ci];
                if delta <= 0.0 {
                    None
                } else {
                    Some(acc + delta.ln())
                }
            });
        match log_sum {
            Some(total) => -self.tau * total,
            None => f64::INFINITY,
        }
    }

    /// Whether the barrier curvature is dominated by the tightest bound.
    ///
    /// Returns `true` when the point is infeasible, when `tau / min_slack²`
    /// reaches `saturation_threshold`, when the median slack is not a finite
    /// positive number, or when the smallest slack is below `ratio` times
    /// the median slack. Returns `false` when there are no bounds.
    pub fn barrier_curvature_locally_concentrated(
        &self,
        beta: &Array1<f64>,
        ratio: f64,
        saturation_threshold: f64,
    ) -> bool {
        let mut slacks = match self.slacks(beta) {
            Some(values) => values,
            None => return true,
        };
        if slacks.is_empty() {
            return false;
        }
        let min_slack = slacks.iter().copied().fold(f64::INFINITY, f64::min);
        let can_check_saturation =
            min_slack > 0.0 && min_slack.is_finite() && saturation_threshold.is_finite();
        if can_check_saturation && self.tau / (min_slack * min_slack) >= saturation_threshold {
            return true;
        }
        slacks.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let mid = slacks.len() / 2;
        let median = if slacks.len() % 2 == 1 {
            slacks[mid]
        } else {
            0.5 * (slacks[mid - 1] + slacks[mid])
        };
        !median.is_finite() || median <= 0.0 || min_slack < ratio * median
    }

    /// Whether the largest barrier curvature `tau / slack²` exceeds
    /// `threshold * ref_diag`. An infeasible point counts as significant.
    pub fn barrier_curvature_is_significant(
        &self,
        beta: &Array1<f64>,
        ref_diag: f64,
        threshold: f64,
    ) -> bool {
        match self.slacks(beta) {
            None => true,
            Some(slacks) => {
                let mut max_barrier_curv = 0.0_f64;
                for d in slacks {
                    max_barrier_curv = max_barrier_curv.max(self.tau / (d * d));
                }
                max_barrier_curv > threshold * ref_diag
            }
        }
    }
}
/// Wraps another provider and adds the log-barrier contributions of a
/// `BarrierConfig`, evaluated at a fixed coefficient vector.
pub struct BarrierDerivativeProvider<'a> {
/// Provider whose corrections the barrier terms are added to.
pub(crate) inner: &'a dyn HessianDerivativeProvider,
/// Barrier weight copied from the configuration.
pub(crate) tau: f64,
/// Coefficient index of each bound, borrowed from the configuration.
pub(crate) constrained_indices: &'a [usize],
/// Sign of each bound, borrowed from the configuration.
pub(crate) bound_signs: &'a [f64],
/// Slack of each bound at the coefficient vector given to `new`.
pub(crate) slacks: Vec<f64>,
/// Length of the coefficient vector; the dimension of the correction
/// matrices.
pub(crate) p: usize,
}
impl<'a> BarrierDerivativeProvider<'a> {
    /// Builds a barrier-aware wrapper around `inner`, evaluated at `beta`.
    ///
    /// The slacks of all bounds in `config` are computed once here and reused
    /// by every correction.
    ///
    /// # Errors
    ///
    /// Fails when `beta` is infeasible, i.e. `config.slacks(beta)` returns
    /// `None`.
    pub fn new(
        inner: &'a dyn HessianDerivativeProvider,
        config: &'a BarrierConfig,
        beta: &Array1<f64>,
    ) -> Result<Self, String> {
        let slacks = config
            .slacks(beta)
            .ok_or_else(|| "BarrierDerivativeProvider: infeasible point".to_string())?;
        Ok(Self {
            inner,
            tau: config.tau,
            constrained_indices: &config.constrained_indices,
            bound_signs: &config.bound_signs,
            slacks,
            p: beta.len(),
        })
    }

    /// Directional derivative of the barrier Hessian along `u`.
    ///
    /// The barrier Hessian is diagonal with entries `Σ tau / slack²`, so its
    /// derivative along `u` is diagonal with entries
    /// `Σ -2 · tau · sign · u[idx] / slack³`, summed over every bound on
    /// coefficient `idx`.
    pub(crate) fn barrier_correction(&self, u: &Array1<f64>) -> Array2<f64> {
        let mut result = Array2::<f64>::zeros((self.p, self.p));
        for (ci, &idx) in self.constrained_indices.iter().enumerate() {
            let inv_cube = 1.0 / (self.slacks[ci].powi(3));
            // Accumulate rather than assign: one coefficient may carry several
            // bounds (e.g. a lower and an upper bound), and each contributes.
            result[[idx, idx]] += -2.0 * self.tau * self.bound_signs[ci] * u[idx] * inv_cube;
        }
        result
    }

    /// Second directional derivative of the barrier Hessian along `u` and `v`.
    ///
    /// Diagonal with entries `Σ 6 · tau · u[idx] · v[idx] / slack⁴`, summed
    /// over every bound on coefficient `idx`. The squared bound sign is `1`
    /// for bounds built by `BarrierConfig::from_constraints`, so it does not
    /// appear.
    pub(crate) fn barrier_second_correction(
        &self,
        u: &Array1<f64>,
        v: &Array1<f64>,
    ) -> Array2<f64> {
        let mut result = Array2::<f64>::zeros((self.p, self.p));
        for (ci, &idx) in self.constrained_indices.iter().enumerate() {
            let inv_4 = 1.0 / (self.slacks[ci].powi(4));
            // Accumulate so repeated indices sum, matching the barrier Hessian.
            result[[idx, idx]] += 6.0 * self.tau * u[idx] * v[idx] * inv_4;
        }
        result
    }
}
impl HessianDerivativeProvider for BarrierDerivativeProvider<'_> {
    /// Dense first-order correction: the inner provider's correction plus the
    /// barrier term evaluated at `-v_k`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let negated_direction = v_k.mapv(|x| -x);
        let barrier_term = self.barrier_correction(&negated_direction);
        let combined = match self.inner.hessian_derivative_correction(v_k)? {
            None => barrier_term,
            Some(mut inner_term) => {
                inner_term += &barrier_term;
                inner_term
            }
        };
        Ok(Some(combined))
    }

    /// First-order correction as a [`DriftDerivResult`].
    ///
    /// A dense inner result has the barrier term added in place. An
    /// operator-backed inner result is wrapped, together with the barrier
    /// term, in a `CompositeHyperOperator`.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let negated_direction = v_k.mapv(|x| -x);
        let barrier_term = self.barrier_correction(&negated_direction);
        let combined = match self.inner.hessian_derivative_correction_result(v_k)? {
            None => DriftDerivResult::Dense(barrier_term),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_term;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_term),
                    operators: vec![operator],
                    dim_hint: self.p,
                }))
            }
        };
        Ok(Some(combined))
    }

    /// Dense second-order correction: the inner provider's correction plus
    /// the barrier first-order term at `u_kl` and the barrier second-order
    /// term at `(v_k, v_l)`.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let first_order = self.barrier_correction(u_kl);
        let second_order = self.barrier_second_correction(v_k, v_l);
        let barrier_term = &first_order + &second_order;
        let inner_term = self
            .inner
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        let combined = match inner_term {
            None => barrier_term,
            Some(mut dense) => {
                dense += &barrier_term;
                dense
            }
        };
        Ok(Some(combined))
    }

    /// Second-order correction as a [`DriftDerivResult`], combining the
    /// barrier terms with the inner result in the same way as the
    /// first-order variant.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let first_order = self.barrier_correction(u_kl);
        let second_order = self.barrier_second_correction(v_k, v_l);
        let barrier_term = &first_order + &second_order;
        let inner_term = self
            .inner
            .hessian_second_derivative_correction_result(v_k, v_l, u_kl)?;
        let combined = match inner_term {
            None => DriftDerivResult::Dense(barrier_term),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_term;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_term),
                    operators: vec![operator],
                    dim_hint: self.p,
                }))
            }
        };
        Ok(Some(combined))
    }

    /// The barrier always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// Returns `None`: no scalar-GLM ingredients are exposed, so the default
    /// kernel lookup also yields `None`.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}