use dense_projection::{dense_projected_matrix, dense_trace_projected_factor};
/// Panics with `message` as the panic payload and never returns.
///
/// The payload is an owned `String` raised through `std::panic::panic_any`,
/// so code that catches the unwind can recover the text with
/// `downcast_ref::<String>()`.
fn reml_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
/// Error categories reported by this module.
///
/// Each variant only classifies the failure; the human-readable explanation
/// lives in its `reason` field, which is also the complete `Display` output.
#[derive(Debug, Clone)]
pub enum RemlError {
    DimensionMismatch { reason: String },
    NonFiniteValue { reason: String },
    InvalidKernelMode { reason: String },
    ContractViolation { reason: String },
}

impl std::fmt::Display for RemlError {
    /// Writes the variant's `reason` verbatim (no variant name, no padding).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant has the same shape, so one irrefutable or-pattern
        // binds `reason` regardless of which variant `self` is.
        let (Self::DimensionMismatch { reason }
        | Self::NonFiniteValue { reason }
        | Self::InvalidKernelMode { reason }
        | Self::ContractViolation { reason }) = self;
        f.write_str(reason)
    }
}

impl std::error::Error for RemlError {}

impl From<RemlError> for String {
    /// Converts the error into its `Display` text.
    fn from(error: RemlError) -> String {
        error.to_string()
    }
}
pub use crate::test_support::debug_stash;
/// Mutable bookkeeping carried between stochastic trace solves.
///
/// In the code visible here only `cg_warm_starts` is touched: the default
/// `HessianOperator::stochastic_trace_solve_for_probe` removes the entry keyed
/// by its `probe_id`. The per-field notes below are inferred from the field
/// names and should be confirmed against the code that fills them in.
#[derive(Debug, Default)]
pub struct StochasticTraceState {
// NOTE(review): presumably a lower bound on the probe count that is only ever raised — confirm.
pub monotone_probe_floor: usize,
// Per-probe vectors keyed by probe id.
// NOTE(review): presumably previous CG solutions reused as initial guesses — confirm.
pub cg_warm_starts: HashMap<u64, Array1<f64>>,
// NOTE(review): presumably replaces the default relative solve tolerance when set — confirm.
pub solve_rel_tol_override: Option<f64>,
// NOTE(review): presumably the residual norm of the most recent linear solve — confirm.
pub last_linear_residual_norm: Option<f64>,
// NOTE(review): presumably the variance estimate from the most recent probe batch — confirm.
pub last_probe_sigma_sq: Option<f64>,
// NOTE(review): presumably the number of probes used by the most recent estimate — confirm.
pub last_probe_count: usize,
}
/// Smallest operator dimension at which the default `HessianOperator` trace
/// methods hand implicit operators to the Hutch++ estimator.
const HUTCHPP_TRACE_MIN_DIM: usize = 128;

/// Builds the stochastic-trace configuration for an operator of size `dim`.
///
/// The sketch dimension is `dim / 32` limited to `[4, 16]`. The maximum probe
/// count is `max(4 * sketch, 32)` and the minimum probe count is
/// `max(sketch, 8)`. Every other setting keeps its
/// `StochasticTraceConfig::default()` value.
fn hutchpp_config_for_dim(dim: usize) -> StochasticTraceConfig {
    const SKETCH_DIM_PER: usize = 32;
    const SKETCH_DIM_MIN: usize = 4;
    const SKETCH_DIM_MAX: usize = 16;
    const PROBES_PER_SKETCH: usize = 4;
    const PROBES_MAX_FLOOR: usize = 32;
    const PROBES_MIN_FLOOR: usize = 8;

    let sketch_dim = (dim / SKETCH_DIM_PER)
        .max(SKETCH_DIM_MIN)
        .min(SKETCH_DIM_MAX);
    let probes_max = std::cmp::max(sketch_dim * PROBES_PER_SKETCH, PROBES_MAX_FLOOR);
    let probes_min = std::cmp::max(sketch_dim, PROBES_MIN_FLOOR);

    let mut cfg = StochasticTraceConfig::default();
    cfg.hutchpp_sketch_dim = Some(sketch_dim);
    cfg.n_probes_max = probes_max;
    cfg.n_probes_min = probes_min;
    cfg
}
/// Interface to a Hessian-like operator `H` that can be solved against and
/// queried for log-determinant and trace quantities.
///
/// Implementors must supply `logdet`, `trace_hinv_product`, `solve`,
/// `solve_multi`, `active_rank` and `dim`. Every other method has a default
/// built from those, which backends may override with cheaper versions.
pub trait HessianOperator: Send + Sync {
    /// Log-determinant of the operator.
    fn logdet(&self) -> f64;

    /// Trace of the inverse operator applied to the dense matrix `a`
    /// (presumably `tr(H⁻¹ a)`; the exact contract is set by implementors).
    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64;

    /// Downcast hook; the default reports that no exact dense spectral
    /// representation is available.
    fn as_exact_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }

    /// Assembles a dense copy of `H` for tangent projection. The default
    /// reports that the backend does not support it.
    fn assemble_h_dense_for_tangent_projection(&self) -> Result<Array2<f64>, String> {
        Err(String::from("backend does not support tangent projection"))
    }

    /// Operator form of [`Self::trace_hinv_product`].
    ///
    /// Implicit operators on problems of dimension at least
    /// `HUTCHPP_TRACE_MIN_DIM` go to the Hutch++ estimator. Smaller implicit
    /// operators are materialized with a warning. Explicit operators are
    /// materialized silently.
    fn trace_hinv_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit() {
            let dim = self.dim();
            if dim >= HUTCHPP_TRACE_MIN_DIM {
                return hutchpp_estimate_trace_hinv_operator(
                    self,
                    op,
                    &hutchpp_config_for_dim(dim),
                );
            }
            log::warn!(
                "trace_hinv_operator: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product(&op.to_dense())
    }

    /// `trace_hinv_product(a_k)` plus, when present, the same trace of the
    /// third-derivative correction.
    fn trace_hinv_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_hinv_product(a_k);
        third_deriv_correction.map_or(base, |c| base + self.trace_hinv_product(c))
    }

    /// Solves against a single right-hand side.
    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64>;

    /// Solves against every column of `rhs`.
    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64>;

    /// Solve used by the stochastic trace estimators. The default validates
    /// `rel_tol` and then performs an ordinary [`Self::solve`].
    ///
    /// # Panics
    /// Panics when `rel_tol` is not a positive finite number.
    fn stochastic_trace_solve(&self, rhs: &Array1<f64>, rel_tol: f64) -> Array1<f64> {
        assert!(
            rel_tol.is_finite() && rel_tol > 0.0,
            "stochastic trace solve tolerance must be positive and finite"
        );
        self.solve(rhs)
    }

    /// Per-probe variant of [`Self::stochastic_trace_solve`].
    ///
    /// The default drops any warm start stored for `probe_id` (when the state
    /// lock can be taken) and then delegates to the plain solve.
    fn stochastic_trace_solve_for_probe(
        &self,
        rhs: &Array1<f64>,
        rel_tol: f64,
        probe_id: u64,
        state: Option<&Arc<Mutex<StochasticTraceState>>>,
    ) -> Array1<f64> {
        if let Some(shared) = state {
            // A poisoned lock is skipped rather than propagated.
            if let Ok(mut guard) = shared.lock() {
                guard.cg_warm_starts.remove(&probe_id);
            }
        }
        self.stochastic_trace_solve(rhs, rel_tol)
    }

    /// Multi-column variant of [`Self::stochastic_trace_solve`].
    ///
    /// # Panics
    /// Panics when `rel_tol` is not a positive finite number.
    fn stochastic_trace_solve_multi(&self, rhs: &Array2<f64>, rel_tol: f64) -> Array2<f64> {
        assert!(
            rel_tol.is_finite() && rel_tol > 0.0,
            "stochastic trace multi-solve tolerance must be positive and finite"
        );
        self.solve_multi(rhs)
    }

    /// Capability flag; `false` by default.
    fn has_matrix_free_trace_cg_operator(&self) -> bool {
        false
    }

    /// Passes `solve_multi(a)` and `solve_multi(b)` to `trace_matrix_product`
    /// (presumably `tr(H⁻¹ a H⁻¹ b)`). When `a` and `b` are the same object
    /// the solve is done once and reused.
    fn trace_hinv_product_cross(&self, a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let solved_lhs = self.solve_multi(a);
        if std::ptr::eq(a, b) {
            trace_matrix_product(&solved_lhs, &solved_lhs)
        } else {
            let solved_rhs = self.solve_multi(b);
            trace_matrix_product(&solved_lhs, &solved_rhs)
        }
    }

    /// Cross trace between a dense matrix and an operator, with the same
    /// Hutch++ / materialization policy as [`Self::trace_hinv_operator`].
    fn trace_hinv_matrix_operator_cross(
        &self,
        matrix: &Array2<f64>,
        op: &dyn HyperOperator,
    ) -> f64 {
        if op.is_implicit() {
            let dim = self.dim();
            if dim >= HUTCHPP_TRACE_MIN_DIM {
                let config = hutchpp_config_for_dim(dim);
                let dense_side = DenseMatrixHyperOperator {
                    matrix: matrix.clone(),
                };
                return hutchpp_estimate_trace_hinv_operator_cross(self, &dense_side, op, &config);
            }
            log::warn!(
                "trace_hinv_matrix_operator_cross: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product_cross(matrix, &op.to_dense())
    }

    /// Cross trace between two operators.
    ///
    /// When either side is implicit and the problem is large enough, the
    /// Hutch++ estimator is used; the squared variant is used when both
    /// arguments are the same object. Otherwise both sides are materialized.
    fn trace_hinv_operator_cross(
        &self,
        left: &dyn HyperOperator,
        right: &dyn HyperOperator,
    ) -> f64 {
        // Non-short-circuit `|` so both operators are always queried.
        let any_implicit = left.is_implicit() | right.is_implicit();
        if any_implicit {
            let dim = self.dim();
            if dim >= HUTCHPP_TRACE_MIN_DIM {
                let config = hutchpp_config_for_dim(dim);
                // Compare data addresses only; vtable pointers are discarded.
                let left_addr = left as *const dyn HyperOperator as *const ();
                let right_addr = right as *const dyn HyperOperator as *const ();
                return if left_addr == right_addr {
                    hutchpp_estimate_trace_hinv_op_squared(self, left, &config)
                } else {
                    hutchpp_estimate_trace_hinv_operator_cross(self, left, right, &config)
                };
            }
            log::warn!(
                "trace_hinv_operator_cross: materializing implicit HyperOperator(s) — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_hinv_product_cross(&left.to_dense(), &right.to_dense())
    }

    /// Trace used for log-determinant gradients. The default is
    /// [`Self::trace_hinv_product`].
    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        self.trace_hinv_product(a)
    }

    /// For each row `x_i` of `x`, returns `x_i · solve(x_i)`.
    ///
    /// Rows are processed in chunks of roughly `2^16 / ncols` rows so that
    /// each multi-solve stays bounded in size.
    ///
    /// # Panics
    /// Panics when `logdet_traces_match_hinv_kernel()` is false, or (with a
    /// `String` payload) when a row chunk cannot be extracted from `x`.
    fn xt_logdet_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
        assert!(self.logdet_traces_match_hinv_kernel());
        const TARGET_CHUNK_FLOATS: usize = 1 << 16;

        let n = x.nrows();
        let p = x.ncols();
        let block = (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1));
        let mut h = Array1::<f64>::zeros(n);

        // `block >= 1`, so `step_by` visits 0, block, 2*block, ... below n.
        for start in (0..n).step_by(block) {
            let end = (start + block).min(n);
            let rows = match x.try_row_chunk(start..end) {
                Ok(rows) => rows,
                Err(err) => reml_contract_panic(format!(
                    "xt_logdet_kernel_x_diagonal: row chunk failed: {err}"
                )),
            };
            // Columns of the transposed chunk are the rows of `x`.
            let chunk_t = rows.t().to_owned();
            let z_chunk = self.solve_multi(&chunk_t);
            for (i, (row, z_col)) in rows
                .outer_iter()
                .zip(z_chunk.columns().into_iter())
                .enumerate()
            {
                h[start + i] = row
                    .iter()
                    .zip(z_col.iter())
                    .fold(0.0, |acc, (&row_value, &z_value)| acc + row_value * z_value);
            }
        }
        h
    }

    /// Operator form of [`Self::trace_logdet_gradient`].
    ///
    /// Hutch++ is used only for implicit operators on large problems whose
    /// log-determinant traces match the `H⁻¹` kernel. Other implicit
    /// operators are materialized with a warning.
    fn trace_logdet_operator(&self, op: &dyn HyperOperator) -> f64 {
        if op.is_implicit() {
            let dim = self.dim();
            if dim >= HUTCHPP_TRACE_MIN_DIM && self.logdet_traces_match_hinv_kernel() {
                return hutchpp_estimate_trace_hinv_operator(
                    self,
                    op,
                    &hutchpp_config_for_dim(dim),
                );
            }
            log::warn!(
                "trace_logdet_operator: materializing implicit HyperOperator — \
                 backend should provide a matrix-free override"
            );
        }
        self.trace_logdet_gradient(&op.to_dense())
    }

    /// `trace_logdet_gradient(a_k)` plus, when present, the same trace of the
    /// third-derivative correction.
    fn trace_logdet_h_k(
        &self,
        a_k: &Array2<f64>,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_logdet_gradient(a_k);
        third_deriv_correction.map_or(base, |c| base + self.trace_logdet_gradient(c))
    }

    /// Like [`Self::trace_logdet_h_k`] with the leading term supplied as an
    /// operator.
    fn trace_logdet_h_k_operator(
        &self,
        b_k: &dyn HyperOperator,
        third_deriv_correction: Option<&Array2<f64>>,
    ) -> f64 {
        let base = self.trace_logdet_operator(b_k);
        third_deriv_correction.map_or(base, |c| base + self.trace_logdet_gradient(c))
    }

    /// Embeds `scale * block` into the `[start, end)` diagonal block of a
    /// zero `dim × dim` matrix and takes its log-determinant gradient trace.
    fn trace_logdet_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let mut embedded = Array2::<f64>::zeros((p, p));
        let width = end - start;
        for local_row in 0..width {
            let row = start + local_row;
            for local_col in 0..width {
                embedded[[row, start + local_col]] = scale * block[[local_row, local_col]];
            }
        }
        self.trace_logdet_gradient(&embedded)
    }

    /// Same embedding as [`Self::trace_logdet_block_local`], evaluated with
    /// [`Self::trace_hinv_product`].
    fn trace_hinv_block_local(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let mut embedded = Array2::<f64>::zeros((p, p));
        let width = end - start;
        for local_row in 0..width {
            let row = start + local_row;
            for local_col in 0..width {
                embedded[[row, start + local_col]] = scale * block[[local_row, local_col]];
            }
        }
        self.trace_hinv_product(&embedded)
    }

    /// Same embedding as [`Self::trace_logdet_block_local`], evaluated as the
    /// cross trace of the embedded matrix with itself.
    fn trace_hinv_block_local_cross(
        &self,
        block: &Array2<f64>,
        scale: f64,
        start: usize,
        end: usize,
    ) -> f64 {
        let p = self.dim();
        let width = end - start;
        let mut embedded = Array2::<f64>::zeros((p, p));
        for local_row in 0..width {
            let row = start + local_row;
            for local_col in 0..width {
                embedded[[row, start + local_col]] = scale * block[[local_row, local_col]];
            }
        }
        self.trace_hinv_product_cross(&embedded, &embedded)
    }

    /// Negated `trace_matrix_product` of the two solved matrices (presumably
    /// `-tr(H⁻¹ h_j H⁻¹ h_i)`). The solve is shared when both arguments are
    /// the same object.
    fn trace_logdet_hessian_cross(&self, h_i: &Array2<f64>, h_j: &Array2<f64>) -> f64 {
        let solved_i = self.solve_multi(h_i);
        if std::ptr::eq(h_i, h_j) {
            -trace_matrix_product(&solved_i, &solved_i)
        } else {
            let solved_j = self.solve_multi(h_j);
            -trace_matrix_product(&solved_j, &solved_i)
        }
    }

    /// Matrix/operator variant; the default materializes `h_j`.
    fn trace_logdet_hessian_cross_matrix_operator(
        &self,
        h_i: &Array2<f64>,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let dense_j = h_j.to_dense();
        self.trace_logdet_hessian_cross(h_i, &dense_j)
    }

    /// Operator/operator variant; the default materializes both sides.
    fn trace_logdet_hessian_cross_operator(
        &self,
        h_i: &dyn HyperOperator,
        h_j: &dyn HyperOperator,
    ) -> f64 {
        let dense_i = h_i.to_dense();
        let dense_j = h_j.to_dense();
        self.trace_logdet_hessian_cross(&dense_i, &dense_j)
    }

    /// Symmetric table of [`Self::trace_logdet_hessian_cross`] over all pairs
    /// of `matrices`. Only the upper triangle is evaluated and then mirrored.
    fn trace_logdet_hessian_crosses(&self, matrices: &[&Array2<f64>]) -> Array2<f64> {
        let count = matrices.len();
        let mut crosses = Array2::<f64>::zeros((count, count));
        for (i, lhs) in matrices.iter().enumerate() {
            for (offset, rhs) in matrices[i..].iter().enumerate() {
                let j = i + offset;
                let value = self.trace_logdet_hessian_cross(*lhs, *rhs);
                crosses[[i, j]] = value;
                crosses[[j, i]] = value;
            }
        }
        crosses
    }

    /// Number of active directions; `hessian_operator_geometric_scale`
    /// divides the log-determinant by it.
    fn active_rank(&self) -> usize;

    /// Side length of the operator.
    fn dim(&self) -> usize;

    /// Whether the backend is dense; `false` by default.
    fn is_dense(&self) -> bool {
        false
    }

    /// Whether stochastic trace estimation is preferred; follows
    /// [`Self::is_dense`] by default.
    fn prefers_stochastic_trace_estimation(&self) -> bool {
        self.is_dense()
    }

    /// Whether log-determinant traces may be computed with the `H⁻¹` kernel;
    /// `true` by default.
    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        true
    }

    /// Downcast hook; `None` by default.
    fn as_dense_spectral(&self) -> Option<&DenseSpectralOperator> {
        None
    }
}
/// Returns `exp(logdet / active_rank)` for `op`.
///
/// This is the geometric mean of the operator's active spectrum when
/// `logdet` is the sum of the logs of `active_rank` eigenvalues (assumed
/// here, not visible in this chunk).
///
/// Returns `None` when the active rank is zero, the log-determinant is not
/// finite, or the resulting scale is not a positive finite number. `logdet`
/// is queried only when the rank is non-zero.
pub fn hessian_operator_geometric_scale(op: &dyn HessianOperator) -> Option<f64> {
    let active = match op.active_rank() {
        0 => return None,
        rank => rank,
    };
    let log_determinant = Some(op.logdet()).filter(|value| value.is_finite())?;
    let geometric_mean = (log_determinant / active as f64).exp();
    (geometric_mean.is_finite() && geometric_mean > 0.0).then_some(geometric_mean)
}
/// Source of directional corrections to Hessian derivatives.
///
/// Implementors must provide the dense first-order correction and
/// `has_corrections`. The remaining methods default to wrappers around those
/// two, or to "nothing available".
pub trait HessianDerivativeProvider: Send + Sync {
    /// Dense first-order correction for direction `v_k`, or `None` when the
    /// provider contributes nothing.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String>;

    /// Same correction as a `DriftDerivResult`; the default wraps the dense
    /// matrix.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let dense = self.hessian_derivative_correction(v_k)?;
        Ok(dense.map(DriftDerivResult::Dense))
    }

    /// Evaluates the first-order correction for every direction in order,
    /// stopping at the first error.
    fn hessian_derivative_corrections_result(
        &self,
        v_ks: &[Array1<f64>],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        let mut results = Vec::with_capacity(v_ks.len());
        for direction in v_ks {
            results.push(self.hessian_derivative_correction_result(direction)?);
        }
        Ok(results)
    }

    /// Capability flag for a batched first-order path; `false` by default.
    fn has_batched_hessian_derivative_corrections(&self) -> bool {
        false
    }

    /// Dense second-order correction.
    ///
    /// The default returns `Ok(None)` for providers without corrections. A
    /// provider that reports corrections but does not override this method
    /// gets an error.
    ///
    /// # Panics
    /// Panics when any input contains NaN.
    fn hessian_second_derivative_correction(
        &self,
        arr: &Array1<f64>,
        arr2: &Array1<f64>,
        arr3: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        assert!(arr.iter().all(|v| !v.is_nan()));
        assert!(arr2.iter().all(|v| !v.is_nan()));
        assert!(arr3.iter().all(|v| !v.is_nan()));
        if !self.has_corrections() {
            return Ok(None);
        }
        Err(
            "HessianDerivativeProvider reports first-order corrections but does not implement second-order correction"
                .to_string(),
        )
    }

    /// Second-order correction as a `DriftDerivResult`; the default wraps the
    /// dense matrix.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let dense = self.hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        Ok(dense.map(DriftDerivResult::Dense))
    }

    /// Evaluates the second-order correction for every `(v_k, v_l, u_kl)`
    /// triple in order, stopping at the first error.
    fn hessian_second_derivative_corrections_result(
        &self,
        triples: &[(Array1<f64>, Array1<f64>, Array1<f64>)],
    ) -> Result<Vec<Option<DriftDerivResult>>, String> {
        let mut results = Vec::with_capacity(triples.len());
        for (v_k, v_l, u_kl) in triples {
            results.push(self.hessian_second_derivative_correction_result(v_k, v_l, u_kl)?);
        }
        Ok(results)
    }

    /// Capability flag for a batched second-order path; `false` by default.
    fn has_batched_hessian_second_derivative_corrections(&self) -> bool {
        false
    }

    /// Whether this provider contributes any correction at all.
    fn has_corrections(&self) -> bool;

    /// Borrowed scalar-GLM data, when the provider has that form; `None` by
    /// default.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }

    /// Owned kernel for outer Hessian derivatives; the default is built from
    /// [`Self::scalar_glm_ingredients`] when those exist.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        match self.scalar_glm_ingredients() {
            Some(ingredients) => Some(OuterHessianDerivativeKernel::from_scalar_glm(ingredients)),
            None => None,
        }
    }

    /// Family-specific outer Hessian operator; `None` by default.
    fn family_outer_hessian_operator(
        &self,
    ) -> Option<Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>> {
        None
    }
}
/// Borrowed bundle returned by `HessianDerivativeProvider::scalar_glm_ingredients`.
///
/// `OuterHessianDerivativeKernel::from_scalar_glm` clones every field into an
/// owned `ScalarGlm` kernel. `SinglePredictorGlmDerivatives` fills it from its
/// own `c_array`, `d_array` and `x_transformed` fields.
pub struct ScalarGlmIngredients<'a> {
// Per-observation array used for first-order corrections.
// NOTE(review): presumably the third derivative of the log-likelihood w.r.t. eta — confirm.
pub c_array: &'a Array1<f64>,
// Optional per-observation array used for second-order corrections.
// NOTE(review): presumably the fourth derivative w.r.t. eta — confirm.
pub d_array: Option<&'a Array1<f64>>,
// Design matrix the arrays above are paired with.
pub x: &'a DesignMatrix,
}
#[derive(Clone)]
pub enum OuterHessianDerivativeKernel {
Gaussian,
ScalarGlm {
c_array: Array1<f64>,
d_array: Option<Array1<f64>>,
x: DesignMatrix,
},
Callback {
first: Arc<dyn Fn(&Array1<f64>) -> Result<Option<DriftDerivResult>, String> + Send + Sync>,
second: Arc<
dyn Fn(&Array1<f64>, &Array1<f64>) -> Result<Option<DriftDerivResult>, String>
+ Send
+ Sync,
>,
},
}
impl OuterHessianDerivativeKernel {
fn from_scalar_glm(ingredients: ScalarGlmIngredients<'_>) -> Self {
Self::ScalarGlm {
c_array: ingredients.c_array.clone(),
d_array: ingredients.d_array.cloned(),
x: ingredients.x.clone(),
}
}
}
/// Derivative provider that contributes no Hessian corrections.
pub struct GaussianDerivatives;

impl HessianDerivativeProvider for GaussianDerivatives {
    /// Always reports that there are no corrections.
    fn has_corrections(&self) -> bool {
        false
    }

    /// Always returns `Ok(None)`.
    ///
    /// # Panics
    /// Panics when `arr` contains NaN.
    fn hessian_derivative_correction(
        &self,
        arr: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        assert!(arr.iter().all(|v| !v.is_nan()));
        Ok(None)
    }

    /// Always the `Gaussian` kernel.
    fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
        Some(OuterHessianDerivativeKernel::Gaussian)
    }
}
/// Derivative provider for a GLM with a single linear predictor.
///
/// Corrections have the form `Xᵀ diag(w) X`, with `X = x_transformed` and the
/// signed weights `w` derived from `c_array`, `d_array` and `hessian_weights`.
pub struct SinglePredictorGlmDerivatives {
    pub c_array: Array1<f64>,
    pub d_array: Option<Array1<f64>>,
    pub hessian_weights: Array1<f64>,
    pub x_transformed: DesignMatrix,
}

impl HessianDerivativeProvider for SinglePredictorGlmDerivatives {
    /// Dense first-order correction `Xᵀ diag(w) X`, where `w` is the negated
    /// directional working curvature along `X v_k`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta_direction = self.x_transformed.matrixvectormultiply(v_k);
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(mut signed_weights) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_direction,
            );
        signed_weights.mapv_inplace(|weight| -weight);
        self.x_transformed
            .xt_diag_x_signed_op(SignedWeightsView::from_array(&signed_weights))
            .map(Some)
            .map_err(|e| format!("hessian_derivative_correction xtwx: {e}"))
    }

    /// Matrix-free form of the first-order correction. Returns an operator
    /// holding a clone of the design and the negated weights instead of
    /// forming the dense product.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let eta_direction = self.x_transformed.matrixvectormultiply(v_k);
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(mut neg_c_xv) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_direction,
            );
        neg_c_xv.mapv_inplace(|weight| -weight);
        let operator = GlmCurvatureCorrectionOperator {
            x_design: self.x_transformed.clone(),
            neg_c_xv,
            p: self.x_transformed.ncols(),
        };
        Ok(Some(DriftDerivResult::Operator(Arc::new(operator))))
    }

    /// Dense second-order correction `Xᵀ diag(w) X`.
    ///
    /// `w` starts as the directional working curvature along `X u_kl`. When
    /// `d_array` is present, `d · (X v_k) · (X v_l)` is added for observations
    /// with a positive Hessian weight, skipping non-finite increments.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta_k = self.x_transformed.matrixvectormultiply(v_k);
        let eta_l = self.x_transformed.matrixvectormultiply(v_l);
        let eta_kl = self.x_transformed.matrixvectormultiply(u_kl);

        let mut weights = Array1::zeros(self.x_transformed.nrows());
        let crate::pirls::DirectionalWorkingCurvature::Diagonal(first_weights) =
            crate::pirls::directionalworking_curvature_from_c_array(
                &self.c_array,
                &self.hessian_weights,
                &eta_kl,
            );
        weights.assign(&first_weights);

        if let Some(d_array) = self.d_array.as_ref() {
            Zip::from(&mut weights)
                .and(d_array)
                .and(&eta_k)
                .and(&eta_l)
                .and(&self.hessian_weights)
                .par_for_each(|weight, &d, &xvk, &xvl, &h| {
                    let delta = d * xvk * xvl;
                    if h > 0.0 && delta.is_finite() {
                        *weight += delta;
                    }
                });
        }

        self.x_transformed
            .xt_diag_x_signed_op(SignedWeightsView::from_array(&weights))
            .map(Some)
            .map_err(|e| format!("hessian_second_derivative_correction xtwx: {e}"))
    }

    /// This provider always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// Exposes borrowed `c_array`, `d_array` and `x_transformed`.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        let ingredients = ScalarGlmIngredients {
            c_array: &self.c_array,
            d_array: self.d_array.as_ref(),
            x: &self.x_transformed,
        };
        Some(ingredients)
    }
}
/// GLM derivative provider that subtracts Firth-operator terms from the
/// corrections of an underlying `SinglePredictorGlmDerivatives`.
pub struct FirthAwareGlmDerivatives {
    pub(super) base: SinglePredictorGlmDerivatives,
    pub(super) firth_op: std::sync::Arc<super::FirthDenseOperator>,
}

impl HessianDerivativeProvider for FirthAwareGlmDerivatives {
    /// Base first-order correction minus `hphi_direction` evaluated along
    /// `-(x_dense · v_k)`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let base_corr = self.base.hessian_derivative_correction(v_k)?;
        let deta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_k).mapv(|value| -value);
        let direction = self.firth_op.direction_from_deta(deta_k);
        let firth_corr = self.firth_op.hphi_direction(&direction);
        let combined = match base_corr {
            Some(mut acc) => {
                acc -= &firth_corr;
                acc
            }
            None => -firth_corr,
        };
        Ok(Some(combined))
    }

    /// Base second-order correction minus two Firth terms: the first-order
    /// term along `x_dense · u_kl` and the second-order term along the
    /// negated `v_k` / `v_l` directions, materialized against the identity.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let base_corr = self
            .base
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?;

        // First-order Firth term along u_kl (not negated).
        let deta_kl: Array1<f64> = crate::faer_ndarray::fast_av(&self.firth_op.x_dense, u_kl);
        let dir_kl = self.firth_op.direction_from_deta(deta_kl);
        let firth_first = self.firth_op.hphi_direction(&dir_kl);

        // Second-order Firth term along the negated v_k / v_l directions.
        let deta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_k).mapv(|value| -value);
        let dir_k = self.firth_op.direction_from_deta(deta_k);
        let deta_l: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_l).mapv(|value| -value);
        let dir_l = self.firth_op.direction_from_deta(deta_l);
        let p = v_k.len();
        let identity = Array2::<f64>::eye(p);
        let firth_second = self
            .firth_op
            .hphisecond_direction_apply(&dir_k, &dir_l, &identity);

        let mut total = base_corr.unwrap_or_else(|| Array2::zeros((p, p)));
        total -= &firth_first;
        total -= &firth_second;
        Ok(Some(total))
    }

    /// Operator-aware first-order correction. A base operator is kept
    /// implicit inside a `CompositeHyperOperator` whose dense part is the
    /// negated Firth term.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let base = self.base.hessian_derivative_correction_result(v_k)?;
        let deta_k: Array1<f64> =
            crate::faer_ndarray::fast_av(&self.firth_op.x_dense, v_k).mapv(|value| -value);
        let direction = self.firth_op.direction_from_deta(deta_k);
        let neg_firth_corr = -self.firth_op.hphi_direction(&direction);

        let combined = match base {
            None => DriftDerivResult::Dense(neg_firth_corr),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &neg_firth_corr;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(neg_firth_corr),
                    operators: vec![operator],
                    dim_hint: self.base.x_transformed.ncols(),
                }))
            }
        };
        Ok(Some(combined))
    }

    /// This provider always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// The scalar-GLM shortcut is not exposed by this provider.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}
/// A Jeffreys term given by a Firth operator, an explicit value, or both.
///
/// When both are present the explicit value takes precedence in
/// [`ExactJeffreysTerm::value`].
#[derive(Clone)]
pub struct ExactJeffreysTerm {
    operator: Option<std::sync::Arc<super::FirthDenseOperator>>,
    value_override: Option<f64>,
}

impl ExactJeffreysTerm {
    /// Term whose value is taken from `operator`.
    pub(crate) fn new(operator: std::sync::Arc<super::FirthDenseOperator>) -> Self {
        let operator = Some(operator);
        Self {
            operator,
            value_override: None,
        }
    }

    /// Term with a fixed value `phi` and no operator.
    pub(crate) fn value_only(phi: f64) -> Self {
        let value_override = Some(phi);
        Self {
            operator: None,
            value_override,
        }
    }

    /// Term that keeps `operator` but reports `projected_value` as its value.
    pub(crate) fn with_projected_value(
        operator: std::sync::Arc<super::FirthDenseOperator>,
        projected_value: f64,
    ) -> Self {
        let operator = Some(operator);
        let value_override = Some(projected_value);
        Self {
            operator,
            value_override,
        }
    }

    /// The override when set, otherwise the operator's `jeffreys_logdet()`,
    /// otherwise `0.0`.
    #[inline]
    pub(crate) fn value(&self) -> f64 {
        match (self.value_override, self.operator.as_ref()) {
            (Some(fixed), _) => fixed,
            (None, Some(operator)) => operator.jeffreys_logdet(),
            (None, None) => 0.0,
        }
    }

    /// A new shared handle to the operator, if there is one.
    #[inline]
    pub(crate) fn operator_arc(&self) -> Option<std::sync::Arc<super::FirthDenseOperator>> {
        self.operator.clone()
    }
}
/// A cost correction, with an optional gradient, that is applied only when
/// its `include` flag is set.
pub(crate) struct GuardedCorrection {
    value: f64,
    gradient: Option<Array1<f64>>,
    include: bool,
}

impl GuardedCorrection {
    /// Bundles the correction value, its optional gradient and the flag.
    pub(crate) fn new(value: f64, gradient: Option<Array1<f64>>, include: bool) -> Self {
        GuardedCorrection {
            value,
            gradient,
            include,
        }
    }

    /// Adds the value to `cost` when the correction is included.
    pub(crate) fn apply_value(&self, cost: &mut f64) {
        if !self.include {
            return;
        }
        *cost += self.value;
    }

    /// Adds the gradient to the leading entries of `rho_grad` when the
    /// correction is included and a gradient is present.
    ///
    /// # Panics
    /// Panics when the stored gradient is longer than `rho_grad`.
    pub(crate) fn apply_gradient(&self, rho_grad: &mut Array1<f64>) {
        if let (true, Some(grad)) = (self.include, self.gradient.as_ref()) {
            // Only the first `grad.len()` coordinates are touched.
            let mut head = rho_grad.slice_mut(ndarray::s![..grad.len()]);
            head += grad;
        }
    }
}
/// Log-barrier description for simple (single-coefficient) bound constraints.
///
/// Entry `ci` describes the constraint
/// `bound_signs[ci] * beta[constrained_indices[ci]] > lower_bounds[ci]`.
/// The same coefficient index may appear more than once, for example with
/// both signs for a box constraint.
#[derive(Clone, Debug)]
pub struct BarrierConfig {
    /// Barrier strength multiplying the `-ln(slack)` terms.
    pub tau: f64,
    /// Coefficient index of each constraint.
    pub constrained_indices: Vec<usize>,
    /// Right-hand side of each constraint.
    pub lower_bounds: Vec<f64>,
    /// `+1.0` or `-1.0` multiplier applied to the coefficient.
    pub bound_signs: Vec<f64>,
}

impl BarrierConfig {
    /// Extracts the simple bounds from a set of linear inequality constraints.
    ///
    /// A row of `constraints.a` is a simple bound when, ignoring entries
    /// smaller than `1e-14` in magnitude, it has exactly one entry and that
    /// entry is `±1` to within `1e-14`. All other rows are skipped. Returns
    /// `None` when `constraints` is `None` or no row qualifies. `tau` is set
    /// to `1e-6`.
    pub fn from_constraints(
        constraints: Option<&crate::pirls::LinearInequalityConstraints>,
    ) -> Option<Self> {
        const SIMPLE_BOUND_ENTRY_TOL: f64 = 1e-14;
        const DEFAULT_BARRIER_TAU: f64 = 1e-6;
        let constraints = constraints?;
        let mut indices = Vec::new();
        let mut lower_bounds = Vec::new();
        let mut bound_signs = Vec::new();
        for i in 0..constraints.a.nrows() {
            let row = constraints.a.row(i);
            let mut single_col = None;
            let mut single_sign = 0.0_f64;
            let mut is_simple = true;
            for (j, &val) in row.iter().enumerate() {
                if val.abs() < SIMPLE_BOUND_ENTRY_TOL {
                    continue;
                }
                let is_unit = (val - 1.0).abs() < SIMPLE_BOUND_ENTRY_TOL
                    || (val + 1.0).abs() < SIMPLE_BOUND_ENTRY_TOL;
                if is_unit && single_col.is_none() {
                    single_col = Some(j);
                    single_sign = if val > 0.0 { 1.0 } else { -1.0 };
                } else {
                    // Second non-zero entry, or a coefficient other than ±1.
                    is_simple = false;
                    break;
                }
            }
            if let (true, Some(col)) = (is_simple, single_col) {
                indices.push(col);
                lower_bounds.push(constraints.b[i]);
                bound_signs.push(single_sign);
            }
        }
        if indices.is_empty() {
            return None;
        }
        Some(BarrierConfig {
            tau: DEFAULT_BARRIER_TAU,
            constrained_indices: indices,
            lower_bounds,
            bound_signs,
        })
    }

    /// Slack of constraint `ci` at `beta`, or `None` unless it is strictly
    /// positive.
    ///
    /// The test is written as `delta > 0.0` so that a NaN slack (from a NaN
    /// coefficient or bound) is rejected as infeasible instead of slipping
    /// through a `delta <= 0.0` check.
    fn feasible_slack(&self, ci: usize, beta: &Array1<f64>) -> Option<f64> {
        let idx = self.constrained_indices[ci];
        let delta = self.bound_signs[ci] * beta[idx] - self.lower_bounds[ci];
        (delta > 0.0).then_some(delta)
    }

    /// Slacks of all constraints at `beta`, in constraint order.
    ///
    /// Returns `None` as soon as one slack is zero, negative or NaN.
    ///
    /// # Panics
    /// Panics when a constrained index is out of range for `beta`.
    pub fn slacks(&self, beta: &Array1<f64>) -> Option<Vec<f64>> {
        (0..self.constrained_indices.len())
            .map(|ci| self.feasible_slack(ci, beta))
            .collect()
    }

    /// Adds the barrier curvature `tau / slack²` to the matching diagonal
    /// entries of `h`.
    ///
    /// # Errors
    /// Returns an error when `beta` is infeasible (see [`Self::slacks`]).
    pub fn add_barrier_hessian_diagonal(
        &self,
        h: &mut Array2<f64>,
        beta: &Array1<f64>,
    ) -> Result<(), String> {
        let slacks = self
            .slacks(beta)
            .ok_or_else(|| "Barrier: infeasible point (slack ≤ 0)".to_string())?;
        for (&idx, slack) in self.constrained_indices.iter().zip(slacks) {
            // `+=` so several bounds on one coefficient accumulate.
            h[[idx, idx]] += self.tau / (slack * slack);
        }
        Ok(())
    }

    /// Barrier cost `-tau * Σ ln(slack)`, or `+∞` when `beta` is infeasible.
    pub fn barrier_cost(&self, beta: &Array1<f64>) -> f64 {
        let mut log_sum = 0.0_f64;
        for ci in 0..self.constrained_indices.len() {
            let Some(delta) = self.feasible_slack(ci, beta) else {
                return f64::INFINITY;
            };
            log_sum += delta.ln();
        }
        -self.tau * log_sum
    }

    /// Reports whether barrier curvature is concentrated in a few tight
    /// constraints.
    ///
    /// Returns `true` when `beta` is infeasible, when the largest barrier
    /// curvature `tau / min_slack²` reaches `saturation_threshold`, when the
    /// median slack is not a positive finite number, or when the smallest
    /// slack is below `ratio` times the median slack. Returns `false` when
    /// there are no constraints.
    pub fn barrier_curvature_locally_concentrated(
        &self,
        beta: &Array1<f64>,
        ratio: f64,
        saturation_threshold: f64,
    ) -> bool {
        let Some(mut slacks) = self.slacks(beta) else {
            return true;
        };
        if slacks.is_empty() {
            return false;
        }
        let min_slack = slacks.iter().copied().fold(f64::INFINITY, f64::min);
        if min_slack > 0.0 && min_slack.is_finite() && saturation_threshold.is_finite() {
            let max_barrier_curv = self.tau / (min_slack * min_slack);
            if max_barrier_curv >= saturation_threshold {
                return true;
            }
        }
        // `slacks` holds only strictly positive values here, so the total
        // order agrees with the usual numeric order.
        slacks.sort_by(|a, b| a.total_cmp(b));
        let mid = slacks.len() / 2;
        let median = if slacks.len() % 2 == 1 {
            slacks[mid]
        } else {
            0.5 * (slacks[mid - 1] + slacks[mid])
        };
        if !median.is_finite() || median <= 0.0 {
            return true;
        }
        min_slack < ratio * median
    }

    /// Reports whether the largest barrier curvature `tau / slack²` exceeds
    /// `threshold * ref_diag`. An infeasible `beta` counts as significant.
    pub fn barrier_curvature_is_significant(
        &self,
        beta: &Array1<f64>,
        ref_diag: f64,
        threshold: f64,
    ) -> bool {
        let Some(slacks) = self.slacks(beta) else {
            return true;
        };
        let max_barrier_curv = slacks
            .iter()
            .map(|&d| self.tau / (d * d))
            .fold(0.0_f64, f64::max);
        max_barrier_curv > threshold * ref_diag
    }
}
/// Wraps another derivative provider and adds the derivatives of the
/// log-barrier curvature `tau / slack²`.
///
/// The slacks are evaluated once, at the `beta` passed to
/// [`BarrierDerivativeProvider::new`].
pub struct BarrierDerivativeProvider<'a> {
    inner: &'a dyn HessianDerivativeProvider,
    tau: f64,
    constrained_indices: &'a [usize],
    bound_signs: &'a [f64],
    slacks: Vec<f64>,
    p: usize,
}

impl<'a> BarrierDerivativeProvider<'a> {
    /// Builds the provider at `beta`.
    ///
    /// # Errors
    /// Returns an error when `config.slacks(beta)` reports an infeasible
    /// point.
    pub fn new(
        inner: &'a dyn HessianDerivativeProvider,
        config: &'a BarrierConfig,
        beta: &Array1<f64>,
    ) -> Result<Self, String> {
        let slacks = config
            .slacks(beta)
            .ok_or_else(|| "BarrierDerivativeProvider: infeasible point".to_string())?;
        Ok(Self {
            inner,
            tau: config.tau,
            constrained_indices: &config.constrained_indices,
            bound_signs: &config.bound_signs,
            slacks,
            p: beta.len(),
        })
    }

    /// First directional derivative of the barrier curvature along `u`.
    ///
    /// With slack `δ = s·β_idx − l`, the curvature `τ/δ²` changes by
    /// `−2·τ·s·u_idx / δ³`. Contributions are accumulated, so a coefficient
    /// that carries several bounds receives the sum of all of them; this
    /// matches how `BarrierConfig::add_barrier_hessian_diagonal` builds the
    /// curvature itself.
    fn barrier_correction(&self, u: &Array1<f64>) -> Array2<f64> {
        let mut result = Array2::zeros((self.p, self.p));
        for (ci, &idx) in self.constrained_indices.iter().enumerate() {
            let inv_cube = 1.0 / self.slacks[ci].powi(3);
            result[[idx, idx]] += -2.0 * self.tau * self.bound_signs[ci] * u[idx] * inv_cube;
        }
        result
    }

    /// Second directional derivative of the barrier curvature along `u` and
    /// `v`: `6·τ·u_idx·v_idx / δ⁴` (the sign factor squares to one).
    ///
    /// Contributions are accumulated for the same reason as in
    /// [`Self::barrier_correction`].
    fn barrier_second_correction(&self, u: &Array1<f64>, v: &Array1<f64>) -> Array2<f64> {
        let mut result = Array2::zeros((self.p, self.p));
        for (ci, &idx) in self.constrained_indices.iter().enumerate() {
            let inv_4 = 1.0 / self.slacks[ci].powi(4);
            result[[idx, idx]] += 6.0 * self.tau * u[idx] * v[idx] * inv_4;
        }
        result
    }
}
impl HessianDerivativeProvider for BarrierDerivativeProvider<'_> {
    /// Inner first-order correction plus the barrier correction along `-v_k`.
    fn hessian_derivative_correction(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let negated = v_k.mapv(|value| -value);
        let barrier_corr = self.barrier_correction(&negated);
        let combined = match self.inner.hessian_derivative_correction(v_k)? {
            None => barrier_corr,
            Some(mut inner_corr) => {
                inner_corr += &barrier_corr;
                inner_corr
            }
        };
        Ok(Some(combined))
    }

    /// Operator-aware first-order correction. An inner operator stays
    /// implicit inside a `CompositeHyperOperator` whose dense part is the
    /// barrier correction.
    fn hessian_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let negated = v_k.mapv(|value| -value);
        let barrier_corr = self.barrier_correction(&negated);
        let combined = match self.inner.hessian_derivative_correction_result(v_k)? {
            None => DriftDerivResult::Dense(barrier_corr),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_corr;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_corr),
                    operators: vec![operator],
                    dim_hint: self.p,
                }))
            }
        };
        Ok(Some(combined))
    }

    /// Inner second-order correction plus the barrier terms: the first-order
    /// barrier correction along `u_kl` and the second-order one along
    /// `v_k`, `v_l`.
    fn hessian_second_derivative_correction(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let barrier_total =
            &self.barrier_correction(u_kl) + &self.barrier_second_correction(v_k, v_l);
        let inner_corr = self
            .inner
            .hessian_second_derivative_correction(v_k, v_l, u_kl)?;
        let combined = match inner_corr {
            None => barrier_total,
            Some(mut acc) => {
                acc += &barrier_total;
                acc
            }
        };
        Ok(Some(combined))
    }

    /// Operator-aware form of the second-order correction.
    fn hessian_second_derivative_correction_result(
        &self,
        v_k: &Array1<f64>,
        v_l: &Array1<f64>,
        u_kl: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String> {
        let barrier_total =
            &self.barrier_correction(u_kl) + &self.barrier_second_correction(v_k, v_l);
        let inner_result = self
            .inner
            .hessian_second_derivative_correction_result(v_k, v_l, u_kl)?;
        let combined = match inner_result {
            None => DriftDerivResult::Dense(barrier_total),
            Some(DriftDerivResult::Dense(mut dense)) => {
                dense += &barrier_total;
                DriftDerivResult::Dense(dense)
            }
            Some(DriftDerivResult::Operator(operator)) => {
                DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
                    dense: Some(barrier_total),
                    operators: vec![operator],
                    dim_hint: self.p,
                }))
            }
        };
        Ok(Some(combined))
    }

    /// The barrier always contributes corrections.
    fn has_corrections(&self) -> bool {
        true
    }

    /// The scalar-GLM shortcut is not exposed by this provider.
    fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
        None
    }
}
/// Per-hyperparameter-coordinate quantities.
///
/// Nothing in the visible chunk reads these fields, so every note below is
/// inferred from the field names and the sibling `HyperCoordPair`; confirm
/// against the code that builds and consumes this struct.
#[derive(Clone)]
pub struct HyperCoord {
// NOTE(review): presumably the scalar (objective) derivative for this coordinate — confirm.
pub a: f64,
// NOTE(review): presumably the score/gradient vector derivative for this coordinate — confirm.
pub g: Array1<f64>,
// NOTE(review): presumably the Hessian drift for this coordinate — confirm.
pub drift: HyperCoordDrift,
// NOTE(review): presumably the derivative of the penalty log-determinant — confirm.
pub ld_s: f64,
// NOTE(review): presumably true when the drift depends on the coefficients — confirm.
pub b_depends_on_beta: bool,
// NOTE(review): presumably true for smoothing-parameter-like coordinates — confirm.
pub is_penalty_like: bool,
// NOTE(review): presumably an optional Firth contribution to `g` — confirm.
pub firth_g: Option<Array1<f64>>,
// NOTE(review): meaning not derivable from this chunk — confirm.
pub tk_eta_fixed: Option<Array1<f64>>,
// NOTE(review): meaning not derivable from this chunk — confirm.
pub tk_x_fixed: Option<Array2<f64>>,
}
/// Quantities attached to a pair of hyperparameter coordinates.
///
/// The second-order matrix is available densely as `b_mat` and optionally as
/// an operator in `b_operator`.
pub struct HyperCoordPair {
    pub a: f64,
    pub g: Array1<f64>,
    pub b_mat: Array2<f64>,
    pub b_operator: Option<Box<dyn HyperOperator>>,
    pub ld_s: f64,
}

impl HyperCoordPair {
    /// Empty pair: zero scalars, zero-length `g`, a `0 × 0` `b_mat` and no
    /// operator.
    pub fn zero() -> Self {
        let g = Array1::zeros(0);
        let b_mat = Array2::zeros((0, 0));
        HyperCoordPair {
            a: 0.0,
            g,
            b_mat,
            b_operator: None,
            ld_s: 0.0,
        }
    }
}
/// A Hessian-drift derivative held either as a dense matrix or as a shared
/// operator.
#[derive(Clone)]
pub enum DriftDerivResult {
    Dense(Array2<f64>),
    Operator(Arc<dyn HyperOperator>),
}

impl std::fmt::Debug for DriftDerivResult {
    /// Prints the dense shape (`Dense(RxC)`) or a fixed placeholder for
    /// operators, never the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Operator(_) => {
                let mut out = f.debug_tuple("Operator");
                out.field(&"<hyper-operator>");
                out.finish()
            }
            Self::Dense(matrix) => {
                let (rows, cols) = (matrix.nrows(), matrix.ncols());
                let mut out = f.debug_tuple("Dense");
                out.field(&format_args!("{rows}x{cols}"));
                out.finish()
            }
        }
    }
}

impl DriftDerivResult {
    /// Converts into an operator; a dense matrix is wrapped in a
    /// `DenseMatrixHyperOperator`.
    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
        match self {
            Self::Operator(operator) => operator,
            Self::Dense(matrix) => {
                let wrapped = DenseMatrixHyperOperator { matrix };
                Arc::new(wrapped)
            }
        }
    }

    /// Log-determinant gradient trace against `hop`, using the dense or the
    /// operator entry point as appropriate.
    pub fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Operator(operator) => hop.trace_logdet_operator(&**operator),
            Self::Dense(matrix) => hop.trace_logdet_gradient(matrix),
        }
    }

    /// Applies the derivative to `v`.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Operator(operator) => operator.mul_vec(v),
            Self::Dense(matrix) => matrix.dot(v),
        }
    }

    /// Log-determinant Hessian cross trace between `self` and `rhs`.
    ///
    /// In the mixed cases the dense side is always passed first.
    pub fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Dense(left) => match rhs {
                Self::Dense(right) => hop.trace_logdet_hessian_cross(left, right),
                Self::Operator(right) => {
                    hop.trace_logdet_hessian_cross_matrix_operator(left, &**right)
                }
            },
            Self::Operator(left) => match rhs {
                Self::Dense(right) => {
                    hop.trace_logdet_hessian_cross_matrix_operator(right, &**left)
                }
                Self::Operator(right) => {
                    hop.trace_logdet_hessian_cross_operator(&**left, &**right)
                }
            },
        }
    }
}
/// Boxed, thread-safe callback that maps an index and a vector to an optional
/// `DriftDerivResult`.
// NOTE(review): presumably the index selects a hyperparameter coordinate and
// the vector is a coefficient-space direction — confirm against callers.
pub type FixedDriftDerivFn =
Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
/// Bundle returned by a `ContractedPsiSecondOrderFn` callback.
// NOTE(review): the field meanings below are inferred from names only;
// confirm against the code that produces and consumes this struct.
pub struct ContractedPsiSecondOrder {
// NOTE(review): presumably one objective term per coordinate — confirm.
pub objective: Array1<f64>,
// NOTE(review): presumably score terms, one row or column per coordinate — confirm layout.
pub score: Array2<f64>,
// One Hessian-drift result per entry.
pub hessian: Vec<DriftDerivResult>,
// NOTE(review): presumably penalty log-determinant terms per coordinate — confirm.
pub ld_s: Array1<f64>,
}
/// Shared, thread-safe callback that maps a slice of `f64` to an optional
/// `ContractedPsiSecondOrder`, or an error message.
// NOTE(review): presumably the slice holds contraction weights over the psi
// coordinates — confirm against callers.
pub type ContractedPsiSecondOrderFn =
Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;
pub trait HyperOperator: Send + Sync {
fn dim(&self) -> usize;
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
self.mul_vec(&v.to_owned())
}
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
out.assign(&self.mul_vec_view(v));
}
fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let p = factor.nrows();
let k = factor.ncols();
let mut out = Array2::<f64>::zeros((p, k));
if rayon::current_thread_index().is_some() {
for col in 0..k {
let bv = out.column_mut(col);
self.mul_vec_into(factor.column(col), bv);
}
return out;
}
let cols: Vec<Array1<f64>> = (0..k)
.into_par_iter()
.map(|col| {
let mut bv = Array1::<f64>::zeros(p);
self.mul_vec_into(factor.column(col), bv.view_mut());
bv
})
.collect();
for (col, bv) in cols.into_iter().enumerate() {
out.column_mut(col).assign(&bv);
}
out
}
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
let op_factor = self.mul_mat(factor);
factor
.iter()
.zip(op_factor.iter())
.map(|(&f, &bf)| f * bf)
.sum()
}
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> f64 {
assert!(std::mem::size_of_val(factor_cache) > 0);
self.trace_projected_factor(factor)
}
fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
let op_factor = self.mul_mat(factor);
crate::faer_ndarray::fast_atb(factor, &op_factor)
}
fn projected_matrix_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> Array2<f64> {
assert!(std::mem::size_of_val(factor_cache) > 0);
self.projected_matrix(factor)
}
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
let dim = out.nrows();
assert!(start + cols <= dim);
let mut basis = Array1::<f64>::zeros(dim);
for local_col in 0..cols {
let global_col = start + local_col;
basis[global_col] = 1.0;
self.mul_vec_into(basis.view(), out.column_mut(local_col));
basis[global_col] = 0.0;
}
}
fn scaled_add_mul_vec(
&self,
v: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
if scale == 0.0 {
return;
}
let mut work = Array1::<f64>::zeros(out.len());
self.mul_vec_into(v, work.view_mut());
out.scaled_add(scale, &work);
}
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v.view(), bv.view_mut());
u.dot(&bv)
}
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v, bv.view_mut());
u.dot(&bv)
}
fn has_fast_bilinear_view(&self) -> bool {
false
}
/// Materialises the operator as a `dim() x dim()` matrix by applying it to
/// each unit vector in turn; column `j` holds the image of `e_j`.
fn to_dense(&self) -> Array2<f64> {
    let p = self.dim();
    let mut dense = Array2::<f64>::zeros((p, p));
    // One reusable unit vector; exactly one entry is non-zero at any time.
    let mut unit = Array1::<f64>::zeros(p);
    for (j, column) in dense.columns_mut().into_iter().enumerate() {
        unit[j] = 1.0;
        self.mul_vec_into(unit.view(), column);
        unit[j] = 0.0;
    }
    dense
}
/// Required method with no default. Among the implementors in this file,
/// `ImplicitHyperOperator` returns `true`, the dense and block-local operators
/// return `false`, and `CompositeHyperOperator` reports whether any child does.
fn is_implicit(&self) -> bool;
/// Downcast hook: `None` unless the implementor is an `ImplicitHyperOperator`
/// that overrides it.
fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
None
}
/// Downcast hook: `None` unless the implementor is a `CompositeHyperOperator`
/// that overrides it.
fn as_composite(&self) -> Option<&CompositeHyperOperator> {
None
}
/// Downcast hook for `WeightedHyperOperator`; `None` by default.
fn as_weighted(&self) -> Option<&WeightedHyperOperator> {
None
}
/// Optional access to block-local storage; `None` by default.
/// `BlockLocalDrift` overrides this to return `(local, start, end)`.
fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
None
}
/// Downcast hook for `SparseDirectionalHyperOperator`; `None` by default.
fn as_sparse_directional(&self) -> Option<&SparseDirectionalHyperOperator> {
None
}
}
/// Key under which a projected factor is stored in `ProjectedFactorCache`.
///
/// Built by `from_factor_view`: it combines a caller-supplied design id with
/// the factor view's address, shape, strides and a two-word fingerprint of its
/// values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
// Caller-supplied identifier of the design the factor is projected through.
design_id: usize,
// Address of the factor view's first element, stored as an integer.
factor_ptr: usize,
// Shape of the factor view.
rows: usize,
cols: usize,
// Strides of the factor view along axis 0 and axis 1.
row_stride: isize,
col_stride: isize,
// Two fingerprint words computed over the factor's element values.
value_hash: u64,
value_hash2: u64,
}
impl ProjectedFactorKey {
    /// Builds the cache key for `factor` as seen through the design identified
    /// by `design_id`.
    ///
    /// The key records the view's base address, shape and strides, plus the
    /// value fingerprint returned by `projected_factor_value_fingerprint`.
    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
        let strides = factor.strides();
        let row_stride = strides[0];
        let col_stride = strides[1];
        let fingerprint = projected_factor_value_fingerprint(factor);
        let (rows, cols) = (factor.nrows(), factor.ncols());
        let factor_ptr = factor.as_ptr() as usize;
        Self {
            design_id,
            factor_ptr,
            rows,
            cols,
            row_stride,
            col_stride,
            value_hash: fingerprint.0,
            value_hash2: fingerprint.1,
        }
    }
}
/// Computes two independent 64-bit fingerprints of the factor's values.
///
/// Elements are visited in the view's logical iteration order and hashed by
/// their raw IEEE-754 bit patterns, so values that compare equal but differ in
/// bits (for example `0.0` and `-0.0`) produce different fingerprints. Both
/// words also depend on the position of each element.
fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
    // The first word uses the 64-bit FNV offset basis and prime, applied per element.
    const H1_SEED: u64 = 0xcbf2_9ce4_8422_2325;
    const H1_PRIME: u64 = 0x0000_0100_0000_01b3;
    const H2_SEED: u64 = 0x9e37_79b1_85eb_ca87;
    const H2_MULTIPLIER: u64 = 0x94d0_49bb_1331_11eb;
    const INDEX_SALT: u64 = 0x517c_c1b7_2722_0a95;

    factor
        .iter()
        .enumerate()
        .fold((H1_SEED, H2_SEED), |(h1, h2), (idx, value)| {
            let bits = value.to_bits();
            // First word: xor in the index-salted bits, then multiply.
            let salted = bits.wrapping_add((idx as u64).wrapping_mul(INDEX_SALT));
            let next_h1 = (h1 ^ salted).wrapping_mul(H1_PRIME);
            // Second word: xor in index-rotated bits, multiply, then rotate.
            let rotated = bits.rotate_left((idx & 63) as u32);
            let next_h2 = (h2 ^ rotated)
                .wrapping_mul(H2_MULTIPLIER)
                .rotate_left(27);
            (next_h1, next_h2)
        })
}
/// Byte-budgeted cache of projected factors keyed by `ProjectedFactorKey`.
///
/// All state lives behind a single mutex. Concurrent requests for the same
/// missing key are coalesced so that only one caller runs the computation (see
/// `get_or_insert_with`).
pub struct ProjectedFactorCache {
inner: Mutex<ProjectedFactorCacheInner>,
}
/// Mutex-protected state of `ProjectedFactorCache`.
struct ProjectedFactorCacheInner {
// Completed entries.
entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
// Keys currently being computed, each with the marker its waiters block on.
in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
// Monotone counter bumped on every lookup and insert; used as the recency stamp.
next_seq: u64,
// Sum of `bytes` over all stored entries.
total_bytes: usize,
// Byte budget consulted when evicting before an insert.
budget_bytes: usize,
}
/// Rendezvous object shared between the thread computing a value and any
/// threads waiting for that same key.
struct ProjectedFactorInProgress {
// `None` while the computation is running; set once to `Ready` or `Failed`.
state: Mutex<Option<ProjectedFactorInProgressState>>,
// Signalled by the producer after `state` has been set.
ready: Condvar,
// Number of threads currently inside the wait path for this marker.
waiter_count: std::sync::atomic::AtomicUsize,
// Notified each time a waiter registers.
// NOTE(review): nothing in this part of the file waits on this pair or reads
// `waiter_count`; presumably they are observation hooks for tests — confirm.
subscriber_arrived: (Mutex<()>, Condvar),
}
/// Final outcome of an in-progress computation, as published to waiters.
enum ProjectedFactorInProgressState {
// The producer finished; waiters receive a clone of this `Arc`.
Ready(Arc<Array2<f64>>),
// The producer's closure panicked; waiters panic in turn.
Failed,
}
/// One stored cache entry.
struct ProjectedFactorEntry {
// The cached matrix, shared with every caller that requested it.
value: Arc<Array2<f64>>,
// Size charged against the budget: element count times `size_of::<f64>()`.
bytes: usize,
// Value of the cache's sequence counter at the most recent hit or insert.
last_used: u64,
}
impl Default for ProjectedFactorCache {
/// Creates an empty cache whose budget is `DEFAULT_BUDGET_BYTES`.
fn default() -> Self {
Self::with_budget(Self::DEFAULT_BUDGET_BYTES)
}
}
impl ProjectedFactorCache {
    /// Budget used by `Default`: 2 GiB.
    pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;

    /// Creates an empty cache with the given byte budget.
    pub fn with_budget(budget_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(ProjectedFactorCacheInner {
                entries: HashMap::new(),
                in_progress: HashMap::new(),
                next_seq: 0,
                total_bytes: 0,
                budget_bytes,
            }),
        }
    }

    /// Returns the value cached under `key`, computing it with `compute` on a miss.
    ///
    /// Concurrent callers asking for the same missing key are coalesced: the
    /// first one runs `compute` (outside the cache lock) and the others block
    /// until it publishes the result.
    ///
    /// Before inserting, least-recently-used entries are evicted until the new
    /// value fits, but only when the budget is non-zero and the value itself is
    /// no larger than the budget. The value is stored in either case.
    ///
    /// # Panics
    /// - Re-raises the original panic payload if `compute` panics.
    /// - Waiters coalesced onto a failed computation panic with
    ///   "projected factor cache producer panicked".
    /// - Panics if one of the internal locks is poisoned.
    pub fn get_or_insert_with(
        &self,
        key: ProjectedFactorKey,
        compute: impl FnOnce() -> Array2<f64>,
    ) -> Arc<Array2<f64>> {
        enum CacheLookup {
            Hit(Arc<Array2<f64>>),
            Wait(Arc<ProjectedFactorInProgress>),
            Compute(Arc<ProjectedFactorInProgress>),
        }
        let lookup = {
            let mut inner = self
                .inner
                .lock()
                .expect("projected factor cache lock poisoned");
            inner.next_seq += 1;
            let now = inner.next_seq;
            if let Some(entry) = inner.entries.get_mut(&key) {
                entry.last_used = now;
                CacheLookup::Hit(entry.value.clone())
            } else if let Some(waiter) = inner.in_progress.get(&key) {
                CacheLookup::Wait(waiter.clone())
            } else {
                let marker = Arc::new(ProjectedFactorInProgress {
                    state: Mutex::new(None),
                    ready: Condvar::new(),
                    waiter_count: std::sync::atomic::AtomicUsize::new(0),
                    subscriber_arrived: (Mutex::new(()), Condvar::new()),
                });
                inner.in_progress.insert(key, marker.clone());
                CacheLookup::Compute(marker)
            }
        };
        match lookup {
            CacheLookup::Hit(value) => value,
            CacheLookup::Wait(marker) => {
                marker
                    .waiter_count
                    .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
                // Announce this waiter's arrival to anyone observing the marker.
                let (lock, cv) = &marker.subscriber_arrived;
                drop(
                    lock.lock()
                        .expect("subscriber-arrived notification lock poisoned"),
                );
                cv.notify_all();
                let mut guard = marker
                    .state
                    .lock()
                    .expect("projected factor in-progress lock poisoned");
                // `Some(value)` when the producer succeeded, `None` when it failed.
                let outcome = loop {
                    match guard.as_ref() {
                        Some(ProjectedFactorInProgressState::Ready(value)) => {
                            break Some(value.clone());
                        }
                        Some(ProjectedFactorInProgressState::Failed) => break None,
                        None => {
                            guard = marker
                                .ready
                                .wait(guard)
                                .expect("projected factor in-progress wait poisoned");
                        }
                    }
                };
                // Release the state lock before a possible panic: unwinding with
                // the guard alive would poison the mutex, and the remaining
                // waiters would then fail on the poison check instead of
                // observing `Failed`.
                drop(guard);
                marker
                    .waiter_count
                    .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
                match outcome {
                    Some(value) => value,
                    None => reml_contract_panic("projected factor cache producer panicked"),
                }
            }
            CacheLookup::Compute(marker) => {
                let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
                    Ok(value) => value,
                    Err(payload) => {
                        {
                            let mut inner = self
                                .inner
                                .lock()
                                .expect("projected factor cache lock poisoned");
                            inner.in_progress.remove(&key);
                        }
                        {
                            // Publish the failure in its own scope so the guard is
                            // released before unwinding resumes; holding it across
                            // `resume_unwind` would poison the state mutex.
                            let mut guard = marker
                                .state
                                .lock()
                                .expect("projected factor in-progress lock poisoned");
                            *guard = Some(ProjectedFactorInProgressState::Failed);
                        }
                        marker.ready.notify_all();
                        resume_unwind(payload);
                    }
                };
                let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
                let value = {
                    let mut inner = self
                        .inner
                        .lock()
                        .expect("projected factor cache lock poisoned");
                    inner.next_seq += 1;
                    let now = inner.next_seq;
                    if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
                        // Evict least-recently-used entries until the new value fits.
                        while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
                            && !inner.entries.is_empty()
                        {
                            let Some(oldest_key) = inner
                                .entries
                                .iter()
                                .min_by_key(|(_, e)| e.last_used)
                                .map(|(k, _)| *k)
                            else {
                                break;
                            };
                            if let Some(removed) = inner.entries.remove(&oldest_key) {
                                inner.total_bytes =
                                    inner.total_bytes.saturating_sub(removed.bytes);
                            }
                        }
                    }
                    let value = if let Some(entry) = inner.entries.get_mut(&key) {
                        entry.last_used = now;
                        entry.value.clone()
                    } else {
                        inner.entries.insert(
                            key,
                            ProjectedFactorEntry {
                                value: computed.clone(),
                                bytes,
                                last_used: now,
                            },
                        );
                        inner.total_bytes = inner.total_bytes.saturating_add(bytes);
                        computed
                    };
                    inner.in_progress.remove(&key);
                    value
                };
                {
                    let mut guard = marker
                        .state
                        .lock()
                        .expect("projected factor in-progress lock poisoned");
                    *guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
                }
                marker.ready.notify_all();
                value
            }
        }
    }

    /// Number of stored entries; reports `0` if the cache lock is poisoned.
    pub fn len(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.entries.len())
            .unwrap_or(0)
    }

    /// Bytes currently charged against the budget; reports `0` if the cache
    /// lock is poisoned.
    pub fn total_bytes(&self) -> usize {
        self.inner
            .lock()
            .map(|inner| inner.total_bytes)
            .unwrap_or(0)
    }

    /// Returns `true` when `len()` is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// `HyperOperator` backed by an explicit dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
// The stored matrix; `dim()` reports its row count.
pub matrix: Array2<f64>,
}
impl HyperOperator for DenseMatrixHyperOperator {
    /// Row count of the stored matrix.
    fn dim(&self) -> usize {
        self.matrix.nrows()
    }

    /// Dense matrix-vector product, returned as a new vector.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        self.mul_vec_view(v.view())
    }

    /// Dense matrix-vector product for a borrowed view.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        self.matrix.dot(&v)
    }

    /// Dense matrix-vector product via the `dense_matvec_into` helper.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_into(&self.matrix, v, out);
    }

    /// Copies columns `start..start + out.ncols()` of the matrix into `out`.
    ///
    /// # Panics
    /// Panics if the requested column range runs past the matrix.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let end = start + out.ncols();
        assert!(end <= self.matrix.ncols());
        // Applying the matrix to `e_j` just selects column `j`.
        let columns = self.matrix.slice(ndarray::s![.., start..end]);
        out.assign(&columns);
    }

    /// Accumulates `scale * matrix · v` into `out` via the dense helper.
    fn scaled_add_mul_vec(&self, v: ArrayView1<'_, f64>, scale: f64, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_scaled_add_into(&self.matrix, v, scale, out);
    }

    /// Bilinear form of the matrix with `v` and `u`, via `dense_bilinear`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        self.bilinear_view(v.view(), u.view())
    }

    /// View-based bilinear form, via `dense_bilinear`.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        dense_bilinear(&self.matrix, v, u)
    }

    /// Returns a copy of the stored matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.matrix.clone()
    }

    /// A dense operator is never implicit.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// Sum of an optional dense matrix and any number of child operators.
#[derive(Clone)]
pub struct CompositeHyperOperator {
// Optional dense summand.
pub dense: Option<Array2<f64>>,
// Operator summands; each one contributes additively.
pub operators: Vec<Arc<dyn HyperOperator>>,
// Reported by `dim()`; also sizes the zero matrix in `to_dense` when `dense` is `None`.
pub dim_hint: usize,
}
fn composite_trace_implicit_batched(
operators: &[Arc<dyn HyperOperator>],
factor: &Array2<f64>,
cache: Option<&ProjectedFactorCache>,
) -> f64 {
let mut trace = 0.0;
let mut group_starts: Vec<Vec<usize>> = Vec::new();
let mut handled = vec![false; operators.len()];
for (i, op) in operators.iter().enumerate() {
if handled[i] {
continue;
}
let Some(impl_i) = op.as_implicit() else {
continue;
};
let mut group = vec![i];
handled[i] = true;
for j in (i + 1)..operators.len() {
if handled[j] {
continue;
}
if let Some(impl_j) = operators[j].as_implicit()
&& Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
&& Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
&& Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
&& impl_i.p == impl_j.p
{
group.push(j);
handled[j] = true;
}
}
group_starts.push(group);
}
for group in &group_starts {
if group.len() >= 2 {
let lead = operators[group[0]].as_implicit().unwrap();
let xf = match cache {
Some(c) => lead.cached_xf(factor, c),
None => Arc::new(lead.compute_xf(factor)),
};
let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
.iter()
.map(|&k| {
let op = operators[k].as_implicit().unwrap();
(op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
})
.collect();
let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
trace += values.iter().sum::<f64>();
} else {
let op = &operators[group[0]];
trace += match cache {
Some(c) => op.trace_projected_factor_cached(factor, c),
None => op.trace_projected_factor(factor),
};
}
}
for (i, op) in operators.iter().enumerate() {
if handled[i] {
continue;
}
trace += match cache {
Some(c) => op.trace_projected_factor_cached(factor, c),
None => op.trace_projected_factor(factor),
};
}
trace
}
fn trace_projected_factors_batched(
operators: &[Arc<dyn HyperOperator>],
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> Vec<f64> {
let mut out = vec![0.0; operators.len()];
let mut handled = vec![false; operators.len()];
for i in 0..operators.len() {
if handled[i] {
continue;
}
let Some(impl_i) = operators[i].as_implicit() else {
out[i] = operators[i].trace_projected_factor_cached(factor, cache);
handled[i] = true;
continue;
};
let mut group = vec![i];
handled[i] = true;
for j in (i + 1)..operators.len() {
if handled[j] {
continue;
}
if let Some(impl_j) = operators[j].as_implicit()
&& Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
&& Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
&& Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
&& impl_i.p == impl_j.p
{
group.push(j);
handled[j] = true;
}
}
if group.len() >= 2 {
let xf = impl_i.cached_xf(factor, cache);
let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
.iter()
.map(|&idx| {
let op = operators[idx].as_implicit().unwrap();
(op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
})
.collect();
let values = impl_i.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
for (&idx, value) in group.iter().zip(values) {
out[idx] = value;
}
} else {
out[i] = operators[i].trace_projected_factor_cached(factor, cache);
}
}
out
}
/// Flattens `op` into weighted leaf operators for a projected-trace evaluation.
///
/// Composite operators are expanded: their dense part is traced immediately
/// into `dense_acc[out_idx]` and their children are visited with the same
/// weight. Weighted operators are expanded with the weights multiplied
/// through. Any other operator is appended to `terms` as
/// `(out_idx, weight, op)`. A zero weight prunes the whole subtree.
fn collect_projected_trace_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [f64],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }
    if let Some(composite) = op.as_composite() {
        if let Some(dense) = composite.dense.as_ref() {
            dense_acc[out_idx] += weight * dense_trace_projected_factor(dense, factor);
        }
        for child in &composite.operators {
            collect_projected_trace_terms(
                out_idx,
                weight,
                child.as_ref(),
                factor,
                dense_acc,
                terms,
            );
        }
        return;
    }
    if let Some(weighted) = op.as_weighted() {
        for (child_weight, child) in &weighted.terms {
            let combined = weight * *child_weight;
            collect_projected_trace_terms(
                out_idx,
                combined,
                child.as_ref(),
                factor,
                dense_acc,
                terms,
            );
        }
        return;
    }
    // Leaf operator: defer it to the batched evaluation.
    terms.push((out_idx, weight, op));
}
/// Flattens `op` into weighted leaf operators for a projected-matrix evaluation.
///
/// Composite operators are expanded: their dense part is projected
/// immediately and added, scaled by `weight`, into `dense_acc[out_idx]`, and
/// their children are visited with the same weight. Weighted operators are
/// expanded with the weights multiplied through. Any other operator is
/// appended to `terms` as `(out_idx, weight, op)`. A zero weight prunes the
/// whole subtree.
fn collect_projected_matrix_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [Array2<f64>],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }
    if let Some(composite) = op.as_composite() {
        if let Some(dense) = composite.dense.as_ref() {
            let projected = dense_projected_matrix(dense, factor);
            dense_acc[out_idx].scaled_add(weight, &projected);
        }
        for child in &composite.operators {
            collect_projected_matrix_terms(
                out_idx,
                weight,
                child.as_ref(),
                factor,
                dense_acc,
                terms,
            );
        }
        return;
    }
    if let Some(weighted) = op.as_weighted() {
        for (child_weight, child) in &weighted.terms {
            let combined = weight * *child_weight;
            collect_projected_matrix_terms(
                out_idx,
                combined,
                child.as_ref(),
                factor,
                dense_acc,
                terms,
            );
        }
        return;
    }
    // Leaf operator: defer it to the batched evaluation.
    terms.push((out_idx, weight, op));
}
fn trace_projected_operator_terms_batched(
n_out: usize,
terms: &[(usize, f64, &dyn HyperOperator)],
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> Vec<f64> {
let mut out = vec![0.0_f64; n_out];
let mut handled = vec![false; terms.len()];
for i in 0..terms.len() {
if handled[i] {
continue;
}
let Some(impl_i) = terms[i].2.as_implicit() else {
continue;
};
let mut group = vec![i];
handled[i] = true;
for j in (i + 1)..terms.len() {
if handled[j] {
continue;
}
if let Some(impl_j) = terms[j].2.as_implicit()
&& Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
&& Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
&& Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
&& impl_i.p == impl_j.p
{
group.push(j);
handled[j] = true;
}
}
let lead = terms[group[0]].2.as_implicit().unwrap();
let xf = lead.cached_xf(factor, cache);
let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
.iter()
.map(|&term_idx| {
let op = terms[term_idx].2.as_implicit().unwrap();
(op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
})
.collect();
let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
for (&term_idx, value) in group.iter().zip(values.iter()) {
let (out_idx, weight, _) = terms[term_idx];
out[out_idx] += weight * *value;
}
}
for (i, (out_idx, weight, op)) in terms.iter().enumerate() {
if handled[i] {
continue;
}
out[*out_idx] += *weight * op.trace_projected_factor_cached(factor, cache);
}
out
}
/// Accumulates weighted projected matrices for a flat list of leaf terms.
///
/// Returns `n_out` square matrices of size `factor.ncols()`. Matrix `out_idx`
/// is the sum of `weight * projected_matrix_cached(op)` over the terms
/// addressed to it, added in input order.
fn projected_operator_terms_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    let rank = factor.ncols();
    let mut accumulators = vec![Array2::<f64>::zeros((rank, rank)); n_out];
    for &(out_idx, weight, op) in terms {
        let projected = op.projected_matrix_cached(factor, cache);
        accumulators[out_idx].scaled_add(weight, &projected);
    }
    accumulators
}
/// Forwards unchanged to `projected_operator_terms_batched`; see that function
/// for the meaning of the arguments and the result.
fn project_hyper_operators_batched(
n_out: usize,
terms: &[(usize, f64, &dyn HyperOperator)],
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
projected_operator_terms_batched(n_out, terms, factor, cache)
}
/// Computes one projected trace per drift, returned in input order.
///
/// Dense drifts are traced directly with `dense_trace_projected_factor`.
/// Operator drifts are flattened into leaf terms (any dense parts are traced
/// on the spot) and the leaves are evaluated together by
/// `trace_projected_operator_terms_batched`.
fn trace_logdet_drifts_projected_factor_batched(
    drifts: &[DriftDerivResult],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let n_drifts = drifts.len();
    let mut totals = vec![0.0_f64; n_drifts];
    let mut leaf_terms: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
    for (idx, drift) in drifts.iter().enumerate() {
        match drift {
            DriftDerivResult::Dense(matrix) => {
                totals[idx] += dense_trace_projected_factor(matrix, factor);
            }
            DriftDerivResult::Operator(op) => collect_projected_trace_terms(
                idx,
                1.0,
                op.as_ref(),
                factor,
                &mut totals,
                &mut leaf_terms,
            ),
        }
    }
    // Add the batched leaf contributions on top of the dense ones.
    let batched = trace_projected_operator_terms_batched(n_drifts, &leaf_terms, factor, cache);
    totals
        .iter_mut()
        .zip(batched)
        .for_each(|(total, extra)| *total += extra);
    totals
}
/// Batched drift traces for a `DenseSpectralOperator`: forwards to
/// `trace_logdet_drifts_projected_factor_batched` with the operator's own
/// `g_factor` and `projected_factor_cache`.
fn dense_spectral_trace_logdet_drifts_batched(
ds: &DenseSpectralOperator,
drifts: &[DriftDerivResult],
) -> Vec<f64> {
trace_logdet_drifts_projected_factor_batched(drifts, &ds.g_factor, &ds.projected_factor_cache)
}
/// Builds the factor `u_s · V · diag(sqrt(max(λ, 0)))`, where `(λ, V)` is the
/// eigendecomposition of `kernel.h_proj_inverse`.
///
/// Negative eigenvalues are clamped to zero before the square root is taken.
///
/// # Panics
/// Panics if the eigendecomposition fails.
fn penalty_subspace_trace_factor(kernel: &PenaltySubspaceTrace) -> Array2<f64> {
    let (evals, evecs) = kernel
        .h_proj_inverse
        .eigh(faer::Side::Lower)
        .expect("PenaltySubspaceTrace kernel factor eigendecomposition failed");
    let r = evals.len();
    // The eigenvector matrix is owned and not needed afterwards, so scale it
    // in place instead of cloning it first.
    let mut root = evecs;
    for col in 0..r {
        let scale = evals[col].max(0.0).sqrt();
        for row in 0..r {
            root[[row, col]] *= scale;
        }
    }
    crate::faer_ndarray::fast_ab(&kernel.u_s, &root)
}
/// Batched drift traces through the penalty-subspace factor.
///
/// The factor is rebuilt from `kernel` and a fresh default-budget cache is
/// created for this call only; it is dropped on return.
fn penalty_subspace_trace_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    let subspace_factor = penalty_subspace_trace_factor(kernel);
    let call_local_cache = ProjectedFactorCache::default();
    trace_logdet_drifts_projected_factor_batched(drifts, &subspace_factor, &call_local_cache)
}
/// Reduces every drift through `kernel`, returning one matrix per drift in
/// input order. Dense drifts use `kernel.reduce`; operator drifts use
/// `kernel.reduce_operator`.
fn penalty_subspace_reduce_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<Array2<f64>> {
    let mut reduced = Vec::with_capacity(drifts.len());
    for drift in drifts {
        let matrix = match drift {
            DriftDerivResult::Dense(dense) => kernel.reduce(dense),
            DriftDerivResult::Operator(op) => kernel.reduce_operator(op.as_ref()),
        };
        reduced.push(matrix);
    }
    reduced
}
/// Batched projected traces of `operators` against a `DenseSpectralOperator`'s
/// `g_factor`, using its projected-factor cache.
///
/// Returns an empty vector for an empty input. When info-level logging is
/// enabled, the call is timed and reported through `dense_spectral_stage_log`.
fn dense_spectral_trace_logdet_operators_batched(
    ds: &DenseSpectralOperator,
    operators: &[Arc<dyn HyperOperator>],
) -> Vec<f64> {
    if operators.is_empty() {
        return Vec::new();
    }
    if !log::log_enabled!(log::Level::Info) {
        // Logging is off, so skip the timing and counting entirely.
        return trace_projected_factors_batched(
            operators,
            &ds.g_factor,
            &ds.projected_factor_cache,
        );
    }
    let start = std::time::Instant::now();
    let traces =
        trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache);
    let implicit_count = operators.iter().filter(|op| op.is_implicit()).count();
    dense_spectral_stage_log(
        &format!(
            "DenseSpectralOperator::trace_logdet_operators_batched dim={} rank={} ops={} implicit_ops={}",
            ds.n_dim,
            ds.g_factor.ncols(),
            operators.len(),
            implicit_count,
        ),
        start.elapsed().as_secs_f64(),
    );
    traces
}
impl HyperOperator for CompositeHyperOperator {
    /// A composite exposes itself so callers can flatten it.
    fn as_composite(&self) -> Option<&CompositeHyperOperator> {
        Some(self)
    }

    /// Reports the stored `dim_hint`.
    fn dim(&self) -> usize {
        self.dim_hint
    }

    /// Applies the sum of all parts to `v`, returning a new vector.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut product = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), product.view_mut());
        product
    }

    /// View-based counterpart of `mul_vec`.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut product = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, product.view_mut());
        product
    }

    /// Writes the sum of all parts applied to `v` into `out`.
    ///
    /// A composite holding exactly one operator and no dense part delegates
    /// straight to that operator.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            only.mul_vec_into(v, out);
            return;
        }
        out.fill(0.0);
        if let Some(dense) = &self.dense {
            dense_matvec_into(dense, v, out.view_mut());
        }
        for child in &self.operators {
            child.scaled_add_mul_vec(v, 1.0, out.view_mut());
        }
    }

    /// Writes the images of unit vectors `start..start + out.ncols()` into the
    /// columns of `out`, summing the dense part and every child.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            only.mul_basis_columns_into(start, out);
            return;
        }
        out.fill(0.0);
        let cols = out.ncols();
        let end = start + cols;
        if let Some(dense) = &self.dense {
            out += &dense.slice(ndarray::s![.., start..end]);
        }
        // Each child fills the scratch block, which is then added to the output.
        let mut scratch = Array2::<f64>::zeros((out.nrows(), cols));
        for child in &self.operators {
            child.mul_basis_columns_into(start, scratch.view_mut());
            out += &scratch;
        }
    }

    /// Accumulates `scale` times the sum of all parts applied to `v` into `out`.
    /// A zero `scale` is a no-op.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            only.scaled_add_mul_vec(v, scale, out);
            return;
        }
        if let Some(dense) = &self.dense {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        for child in &self.operators {
            child.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Applies the sum of all parts to every column of `factor`.
    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            return only.mul_mat(factor);
        }
        let shape = (factor.nrows(), factor.ncols());
        let mut product = Array2::<f64>::zeros(shape);
        if let Some(dense) = &self.dense {
            product += &dense.dot(factor);
        }
        for child in &self.operators {
            product += &child.mul_mat(factor);
        }
        product
    }

    /// Projected trace without a cache: the dense contribution plus the
    /// batched trace of the children.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            return only.trace_projected_factor(factor);
        }
        let mut trace = 0.0;
        if let Some(dense) = &self.dense {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(lhs, rhs)| *lhs * *rhs)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, None);
        trace
    }

    /// Projected trace using `cache` for the shared products of the children.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            return only.trace_projected_factor_cached(factor, cache);
        }
        let mut trace = 0.0;
        if let Some(dense) = &self.dense {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(lhs, rhs)| *lhs * *rhs)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
        trace
    }

    /// Sum of the projected matrices of the dense part and every child.
    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            return only.projected_matrix(factor);
        }
        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = &self.dense {
            let dense_factor = crate::faer_ndarray::fast_ab(dense, factor);
            projected += &crate::faer_ndarray::fast_atb(factor, &dense_factor);
        }
        for child in &self.operators {
            projected += &child.projected_matrix(factor);
        }
        projected
    }

    /// Cache-aware counterpart of `projected_matrix`.
    fn projected_matrix_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> Array2<f64> {
        if let (None, [only]) = (self.dense.as_ref(), self.operators.as_slice()) {
            return only.projected_matrix_cached(factor, cache);
        }
        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = &self.dense {
            let dense_factor = crate::faer_ndarray::fast_ab(dense, factor);
            projected += &crate::faer_ndarray::fast_atb(factor, &dense_factor);
        }
        for child in &self.operators {
            projected += &child.projected_matrix_cached(factor, cache);
        }
        projected
    }

    /// Bilinear form: the dense contribution first, then each child in order.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = &self.dense {
            total += dense_bilinear(dense, v.view(), u.view());
        }
        for child in &self.operators {
            total += child.bilinear(v, u);
        }
        total
    }

    /// View-based counterpart of `bilinear`.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = &self.dense {
            total += dense_bilinear(dense, v, u);
        }
        for child in &self.operators {
            total += child.bilinear_view(v, u);
        }
        total
    }

    /// Materialises the sum: starts from a copy of the dense part (or a
    /// `dim_hint`-sized zero matrix) and adds each child's dense form.
    fn to_dense(&self) -> Array2<f64> {
        let mut total = match &self.dense {
            Some(dense) => dense.clone(),
            None => Array2::<f64>::zeros((self.dim_hint, self.dim_hint)),
        };
        for child in &self.operators {
            total += &child.to_dense();
        }
        total
    }

    /// Implicit as soon as any child is implicit.
    fn is_implicit(&self) -> bool {
        self.operators.iter().any(|child| child.is_implicit())
    }
}
/// Operator that is zero everywhere except on one diagonal block.
///
/// The block occupies rows and columns `start..end` of a `total_dim`-sized
/// operator and its entries are stored in `local`.
#[derive(Clone)]
pub struct BlockLocalDrift {
// Entries of the diagonal block.
pub local: Array2<f64>,
// First index covered by the block.
pub start: usize,
// One past the last index covered by the block.
pub end: usize,
// Dimension of the full operator, as reported by `dim()`.
pub total_dim: usize,
}
impl HyperOperator for BlockLocalDrift {
    /// Dimension of the full operator.
    fn dim(&self) -> usize {
        self.total_dim
    }

    /// Applies the block to `v`, returning a vector of the same length.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.mul_vec_into(v.view(), product.view_mut());
        product
    }

    /// View-based counterpart of `mul_vec`.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.mul_vec_into(v, product.view_mut());
        product
    }

    /// Zeroes `out`, then writes `local · v[start..end]` into `out[start..end]`.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.fill(0.0);
        dense_matvec_into(
            &self.local,
            v.slice(ndarray::s![self.start..self.end]),
            out.slice_mut(ndarray::s![self.start..self.end]),
        );
    }

    /// Writes the images of unit vectors `start..start + out.ncols()` into
    /// `out`. Only columns that fall inside the block are non-zero.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        out.fill(0.0);
        let window_end = start + out.ncols();
        // Intersect the requested column window with the block's column range.
        let overlap_start = std::cmp::max(start, self.start);
        let overlap_end = std::cmp::min(window_end, self.end);
        if overlap_start >= overlap_end {
            return;
        }
        let local_cols = (overlap_start - self.start)..(overlap_end - self.start);
        let out_cols = (overlap_start - start)..(overlap_end - start);
        let source = self.local.slice(ndarray::s![.., local_cols]);
        out.slice_mut(ndarray::s![self.start..self.end, out_cols])
            .assign(&source);
    }

    /// Accumulates `scale * local · v[start..end]` into `out[start..end]`.
    /// A zero `scale` is a no-op.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        dense_matvec_scaled_add_into(
            &self.local,
            v.slice(ndarray::s![self.start..self.end]),
            scale,
            out.slice_mut(ndarray::s![self.start..self.end]),
        );
    }

    /// Bilinear form restricted to the block: `u_block · (local · v_block)`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let u_block = u.slice(ndarray::s![self.start..self.end]);
        let local_v = self.local.dot(&v_block);
        u_block.dot(&local_v)
    }

    /// View-based bilinear form, accumulated row by row without temporaries.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let v_block = v.slice(ndarray::s![self.start..self.end]);
        let u_block = u.slice(ndarray::s![self.start..self.end]);
        self.local
            .rows()
            .into_iter()
            .zip(u_block.iter().copied())
            .fold(0.0, |total, (row, u_value)| {
                let row_dot = row
                    .iter()
                    .copied()
                    .zip(v_block.iter().copied())
                    .fold(0.0, |acc, (entry, v_value)| acc + entry * v_value);
                total + u_value * row_dot
            })
    }

    /// Embeds the block into a `total_dim x total_dim` zero matrix.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::zeros((self.total_dim, self.total_dim));
        let mut block = dense.slice_mut(ndarray::s![self.start..self.end, self.start..self.end]);
        block.assign(&self.local);
        dense
    }

    /// A block-local operator is explicit.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Exposes `(local, start, end)`.
    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
        Some((&self.local, self.start, self.end))
    }
}
/// Drift made of up to three additive parts: a dense matrix, a block-local
/// matrix and a general operator. Any of them may be absent.
#[derive(Clone)]
pub struct HyperCoordDrift {
// Optional dense part.
pub dense: Option<Array2<f64>>,
// Optional part confined to one diagonal block.
pub block_local: Option<BlockLocalDrift>,
// Optional operator part.
pub operator: Option<Arc<dyn HyperOperator>>,
}
impl HyperCoordDrift {
    /// A drift with no parts at all.
    pub fn none() -> Self {
        Self {
            dense: None,
            block_local: None,
            operator: None,
        }
    }

    /// A drift consisting only of a dense matrix.
    pub fn from_dense(dense: Array2<f64>) -> Self {
        Self {
            dense: Some(dense),
            ..Self::none()
        }
    }

    /// A drift consisting only of an operator.
    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
        Self {
            operator: Some(operator),
            ..Self::none()
        }
    }

    /// Combines an optional dense part with an optional operator.
    ///
    /// An empty dense matrix is dropped when an operator is present.
    pub fn from_parts(
        dense: Option<Array2<f64>>,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        // Keep the dense part unless it is an empty placeholder next to an operator.
        let dense = dense.filter(|mat| operator.is_none() || !mat.is_empty());
        Self {
            dense,
            block_local: None,
            operator,
        }
    }

    /// A drift made of a block-local matrix plus an optional operator.
    pub fn from_block_local_and_operator(
        local: Array2<f64>,
        start: usize,
        end: usize,
        total_dim: usize,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let block_local = BlockLocalDrift {
            local,
            start,
            end,
            total_dim,
        };
        Self {
            dense: None,
            block_local: Some(block_local),
            operator,
        }
    }

    /// Whether an operator part is present.
    pub fn has_operator(&self) -> bool {
        self.operator.is_some()
    }

    /// Whether the drift carries an operator or a block-local part.
    pub fn uses_operator_fast_path(&self) -> bool {
        self.has_operator() || self.block_local.is_some()
    }

    /// Borrows the operator part, if any.
    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
        match &self.operator {
            Some(op) => Some(&**op),
            None => None,
        }
    }

    /// Assembles the full dense matrix: dense part, plus block-local part,
    /// plus the operator's dense form. Returns a `0 x 0` matrix when the
    /// inferred dimension is zero.
    pub fn materialize(&self) -> Array2<f64> {
        let p = self.infer_dim();
        if p == 0 {
            return Array2::zeros((0, 0));
        }
        let mut total = match &self.dense {
            Some(dense) => dense.clone(),
            None => Array2::zeros((p, p)),
        };
        if let Some(bl) = self.block_local.as_ref() {
            let mut block = total.slice_mut(ndarray::s![bl.start..bl.end, bl.start..bl.end]);
            block.scaled_add(1.0, &bl.local);
        }
        if let Some(op) = self.operator.as_ref() {
            total += &op.to_dense();
        }
        total
    }

    /// Applies the drift to `v`, returning a new vector of the same length.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.scaled_add_apply(v.view(), 1.0, &mut product);
        product
    }

    /// Accumulates `scale * drift · v` into `out`; a zero `scale` is a no-op.
    ///
    /// # Panics
    /// Panics if `v` and `out` differ in length.
    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
        assert_eq!(v.len(), out.len());
        if scale == 0.0 {
            return;
        }
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        if let Some(bl) = self.block_local.as_ref() {
            dense_matvec_scaled_add_into(
                &bl.local,
                v.slice(ndarray::s![bl.start..bl.end]),
                scale,
                out.slice_mut(ndarray::s![bl.start..bl.end]),
            );
        }
        if let Some(op) = self.operator.as_ref() {
            op.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Dimension taken from the first available part, in the order dense,
    /// operator, block-local; zero when the drift has no parts.
    fn infer_dim(&self) -> usize {
        self.dense
            .as_ref()
            .map(|dense| dense.nrows())
            .or_else(|| self.operator.as_ref().map(|op| op.dim()))
            .or_else(|| self.block_local.as_ref().map(|bl| bl.total_dim))
            .unwrap_or(0)
    }
}
/// Per-thread scratch buffers reused by the implicit operator's matvec so that
/// each call can resize existing vectors instead of allocating new ones.
mod implicit_matvec_scratch {
use std::cell::RefCell;
/// The three reusable buffers. Callers resize them before use.
pub(super) struct Scratch {
pub x_v: Vec<f64>,
pub n_work: Vec<f64>,
pub p_work: Vec<f64>,
}
impl Scratch {
/// Starts with empty vectors, which keeps the thread-local initialiser `const`.
const fn new() -> Self {
Self {
x_v: Vec::new(),
n_work: Vec::new(),
p_work: Vec::new(),
}
}
}
thread_local! {
static SCRATCH: RefCell<Scratch> = const { RefCell::new(Scratch::new()) };
}
/// Runs `f` with exclusive access to the calling thread's scratch buffers.
///
/// The buffers are mutably borrowed from a `RefCell` for the duration of `f`,
/// so calling `with` again from inside `f` on the same thread would panic.
pub(super) fn with<R>(f: impl FnOnce(&mut Scratch) -> R) -> R {
SCRATCH.with(|cell| f(&mut cell.borrow_mut()))
}
}
/// Matrix-free operator for one axis of an implicit design derivative.
///
/// Its action combines the penalty matrix `s_psi` with products involving the
/// design, the weights and `implicit_deriv`'s forward/transpose
/// multiplications for `axis`.
pub struct ImplicitHyperOperator {
// Supplies `forward_mul` / `transpose_mul` for the design derivative.
pub implicit_deriv: std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>,
// Axis passed to every `forward_mul` / `transpose_mul` call.
pub axis: usize,
// Shared design matrix; the `Arc` address doubles as the cache design id.
pub(crate) x_design: std::sync::Arc<DesignMatrix>,
// Per-observation weights; their length is used as the observation count.
pub(crate) w_diag: crate::matrix::SignedWeightsArc,
// Dense penalty part of the operator.
pub s_psi: Array2<f64>,
// Operator dimension, returned by `dim()`.
pub(crate) p: usize,
// Optional per-observation vector used by the `c` correction terms.
// NOTE(review): exact meaning is defined by helpers outside this view — confirm.
pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
}
impl HyperOperator for ImplicitHyperOperator {
fn dim(&self) -> usize {
self.p
}
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v.view(), out.view_mut());
out
}
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v, out.view_mut());
out
}
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
assert_eq!(v.len(), self.p);
let n_obs = self.w_diag.len();
implicit_matvec_scratch::with(|s| {
s.x_v.clear();
s.x_v.resize(n_obs, 0.0);
s.n_work.clear();
s.n_work.resize(n_obs, 0.0);
s.p_work.clear();
s.p_work.resize(self.p, 0.0);
let mut x_v_view = ndarray::ArrayViewMut1::from(s.x_v.as_mut_slice());
let n_work_view = ndarray::ArrayViewMut1::from(s.n_work.as_mut_slice());
let p_work_view = ndarray::ArrayViewMut1::from(s.p_work.as_mut_slice());
design_matrix_apply_view_into(&self.x_design, v, x_v_view.view_mut());
self.matvec_with_shared_xz_into(x_v_view.view(), v, out, n_work_view, p_work_view);
});
}
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
assert!(start + cols <= self.p);
let n_obs = self.w_diag.len();
let mut basis = Array1::<f64>::zeros(self.p);
let mut x_col = Array1::<f64>::zeros(n_obs);
let mut dx_col = Array1::<f64>::zeros(n_obs);
let mut weighted = Array1::<f64>::zeros(n_obs);
let mut term = Array1::<f64>::zeros(self.p);
for local_col in 0..cols {
let global_col = start + local_col;
let mut out_col = out.column_mut(local_col);
out_col.assign(&self.s_psi.column(global_col));
design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(x_col.view())
.par_for_each(|dst, &w, &x| *dst = w * x);
term.assign(
&self
.implicit_deriv
.transpose_mul(self.axis, &weighted.view())
.expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
);
out_col += &term;
basis[global_col] = 1.0;
dx_col.assign(
&self
.implicit_deriv
.forward_mul(self.axis, &basis.view())
.expect("radial scalar evaluation failed during implicit hyper forward_mul"),
);
basis[global_col] = 0.0;
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(dx_col.view())
.par_for_each(|dst, &w, &dx| *dst = w * dx);
design_matrix_transpose_apply_view_into(
&self.x_design,
weighted.view(),
term.view_mut(),
);
out_col += &term;
self.accumulate_c_correction_xt_into(
x_col.view(),
weighted.view_mut(),
term.view_mut(),
out_col,
);
}
}
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
self.bilinear_view(v.view(), u.view())
}
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
assert_eq!(v.len(), self.p);
assert_eq!(u.len(), self.p);
let x_v = design_matrix_apply_view(&self.x_design, v);
let x_u = design_matrix_apply_view(&self.x_design, u);
let dx_v = self
.implicit_deriv
.forward_mul(self.axis, &v)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let dx_u = self
.implicit_deriv
.forward_mul(self.axis, &u)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let w = &*self.w_diag;
let mut design = 0.0;
for i in 0..w.len() {
design += dx_v[i] * w[i] * x_u[i];
design += dx_u[i] * w[i] * x_v[i];
}
design += self.c_correction_bilinear(&x_v, &x_u);
let penalty = dense_bilinear(&self.s_psi, v, u);
design + penalty
}
fn is_implicit(&self) -> bool {
true
}
fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
Some(self)
}
/// Computes the projected-factor trace by first forming `X · factor` with `compute_xf` and
/// then delegating to `trace_projected_factor_with_xf`.
///
/// Returns `0.0` without further work when `factor` has no columns or there are no
/// observations (`w_diag` is empty).
///
/// # Panics
/// Panics if `factor.nrows() != self.p`.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
    assert_eq!(factor.nrows(), self.p);
    // An empty factor or an empty observation set contributes nothing.
    if factor.ncols() == 0 || self.w_diag.len() == 0 {
        return 0.0;
    }
    let projected = self.compute_xf(factor);
    self.trace_projected_factor_with_xf(factor, projected.view())
}
/// Same as `trace_projected_factor`, but obtains `X · factor` through `cached_xf`, so the
/// product may be served from `cache` instead of being recomputed.
///
/// Returns `0.0` when `factor` has no columns or there are no observations.
///
/// # Panics
/// Panics if `factor.nrows() != self.p`.
fn trace_projected_factor_cached(
    &self,
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> f64 {
    assert_eq!(factor.nrows(), self.p);
    // An empty factor or an empty observation set contributes nothing.
    if factor.ncols() == 0 || self.w_diag.len() == 0 {
        return 0.0;
    }
    let projected = self.cached_xf(factor, cache);
    self.trace_projected_factor_with_xf(factor, projected.view())
}
}
/// Picks how many rows to process per chunk so one chunk of `cols` `f64` values per row
/// stays near an 8 MiB budget.
///
/// The budgeted row count is raised to at least 512 rows and then capped at `n_rows`, so the
/// result never exceeds `n_rows` (and is `0` when `n_rows` is `0`). `cols == 0` is treated as
/// one column.
fn byte_balanced_row_chunk(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;

    let row_bytes = cols.max(1) * std::mem::size_of::<f64>();
    let budget_rows = TARGET_BYTES / row_bytes;
    let wanted = if budget_rows < MIN_CHUNK_ROWS {
        MIN_CHUNK_ROWS
    } else {
        budget_rows
    };
    if wanted > n_rows {
        n_rows
    } else {
        wanted
    }
}
impl ImplicitHyperOperator {
/// Forms the dense product `X · factor` one row chunk at a time.
///
/// Chunks of `x_design` rows are fetched with `try_row_chunk`, multiplied by `factor`, and
/// written into the matching rows of the `(n_obs, rank)` output. The chunk height comes from
/// `byte_balanced_row_chunk(self.p + rank, n_obs)`.
///
/// # Panics
/// Raises a contract panic via `reml_contract_panic` if a row chunk cannot be produced.
fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
    let n_obs = self.w_diag.len();
    let rank = factor.ncols();
    let mut xf = Array2::<f64>::zeros((n_obs, rank));
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
    // `chunk_rows` is only zero when `n_obs` is zero, in which case the range is empty;
    // the `.max(1)` merely keeps `step_by` from rejecting a zero step.
    for start in (0..n_obs).step_by(chunk_rows.max(1)) {
        let end = (start + chunk_rows).min(n_obs);
        let rows = match self.x_design.try_row_chunk(start..end) {
            Ok(rows) => rows,
            Err(err) => reml_contract_panic(format!(
                "ImplicitHyperOperator::compute_xf row chunk failed: {err}"
            )),
        };
        let block = crate::faer_ndarray::fast_ab(&rows, factor);
        xf.slice_mut(ndarray::s![start..end, ..]).assign(&block);
    }
    xf
}
/// Returns `X · factor`, going through `cache` so a previously stored product can be reused.
///
/// The cache key combines the address of the shared design (`Arc::as_ptr`) with the factor
/// view; on a miss the product is computed by `compute_xf`.
// NOTE(review): keying on the pointer address assumes the design allocation outlives the
// cache entries — confirm against `ProjectedFactorCache`'s lifetime.
fn cached_xf(&self, factor: &Array2<f64>, cache: &ProjectedFactorCache) -> Arc<Array2<f64>> {
    let key = ProjectedFactorKey::from_factor_view(
        Arc::as_ptr(&self.x_design) as usize,
        factor.view(),
    );
    cache.get_or_insert_with(key, || self.compute_xf(factor))
}
/// Computes the projected-factor trace given a precomputed `xf = X · factor`.
///
/// Rows are processed in chunks. For each chunk the derivative rows from
/// `row_chunk_first_raw(axis, ·)` are multiplied by `unproject_matrix(factor)` to give
/// `dxf`, and two running sums are maintained:
/// * `design_total     += dxf[i, k] * w[i] * xf[i, k]`
/// * `correction_total += c[i] * xf[i, k]²` (only when `c_x_psi_beta` is present)
///
/// The return value is `2 * design_total + correction_total + Σ factor ∘ (s_psi · factor)`.
///
/// # Panics
/// Panics if `xf` is not `(n_obs, rank)` or if the derivative row chunk evaluation fails.
fn trace_projected_factor_with_xf(&self, factor: &Array2<f64>, xf: ArrayView2<'_, f64>) -> f64 {
    let rank = factor.ncols();
    let n_obs = self.w_diag.len();
    assert_eq!(xf.dim(), (n_obs, rank));

    let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
    let w = self.w_diag.as_ref();
    let c_opt = self.c_x_psi_beta.as_ref().map(|arc| arc.as_ref());

    let mut design_total = 0.0_f64;
    let mut correction_total = 0.0_f64;
    // `chunk_rows` is zero only when `n_obs` is zero (empty range); `.max(1)` keeps
    // `step_by` valid in that case.
    for start in (0..n_obs).step_by(chunk_rows.max(1)) {
        let end = (start + chunk_rows).min(n_obs);
        let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
        let kd_chunk = self
            .implicit_deriv
            .row_chunk_first_raw(self.axis, start..end)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
        let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);

        for (i_local, i) in (start..end).enumerate() {
            let w_i = w[i];
            let dxf_row = dxf_chunk.row(i_local);
            let xf_row = xf_chunk.row(i_local);
            design_total = (0..rank).fold(design_total, |acc, k| {
                acc + dxf_row[k] * w_i * xf_row[k]
            });
            if let Some(c) = c_opt {
                let c_i = c[i];
                correction_total = (0..rank).fold(correction_total, |acc, k| {
                    let value = xf_row[k];
                    acc + c_i * value * value
                });
            }
        }
    }

    let penalized = self.s_psi.dot(factor);
    let penalty: f64 = factor
        .iter()
        .zip(penalized.iter())
        .map(|(&f, &s)| f * s)
        .sum();
    2.0 * design_total + correction_total + penalty
}
/// Multi-axis variant of `trace_projected_factor_with_xf`: evaluates the projected-factor
/// trace for every entry of `axes` while sharing `xf` and `unproject_matrix(factor)`.
///
/// Each entry of `axes` is `(axis, s_psi, optional c vector)`. Per row chunk and per axis the
/// derivative rows are fetched with `row_chunk_first_raw(axis, ·)`; the same two running sums
/// as in the single-axis routine are kept per axis. The i-th output is
/// `2 * design_totals[i] + correction_totals[i] + Σ factor ∘ (s_psi_i · factor)`.
///
/// # Panics
/// Panics if `xf` is not `(n_obs, rank)` or if a derivative row chunk evaluation fails.
fn trace_projected_factor_all_axes_with_xf(
    &self,
    factor: &Array2<f64>,
    xf: ArrayView2<'_, f64>,
    axes: &[(usize, &Array2<f64>, Option<&Array1<f64>>)],
) -> Vec<f64> {
    let rank = factor.ncols();
    let n_obs = self.w_diag.len();
    assert_eq!(xf.dim(), (n_obs, rank));

    let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
    // Sized against at least one row, so the chunk height is never zero.
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs.max(1));
    let w = self.w_diag.as_ref();

    let mut design_totals = vec![0.0_f64; axes.len()];
    let mut correction_totals = vec![0.0_f64; axes.len()];
    for start in (0..n_obs).step_by(chunk_rows) {
        let end = (start + chunk_rows).min(n_obs);
        let xf_chunk = xf.slice(ndarray::s![start..end, ..]);

        for (axis_idx, &(axis, _, c_opt_axis)) in axes.iter().enumerate() {
            let kd_chunk = self
                .implicit_deriv
                .row_chunk_first_raw(axis, start..end)
                .expect(
                    "radial scalar evaluation failed during \
                     trace_projected_factor_all_axes_with_xf",
                );
            let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);

            for (i_local, i) in (start..end).enumerate() {
                let w_i = w[i];
                let dxf_row = dxf_chunk.row(i_local);
                let xf_row = xf_chunk.row(i_local);
                for k in 0..rank {
                    design_totals[axis_idx] += dxf_row[k] * w_i * xf_row[k];
                }
                if let Some(c) = c_opt_axis {
                    let c_i = c[i];
                    for k in 0..rank {
                        let value = xf_row[k];
                        correction_totals[axis_idx] += c_i * value * value;
                    }
                }
            }
        }
    }

    axes.iter()
        .zip(design_totals.iter().zip(correction_totals.iter()))
        .map(|(&(_, s_psi, _), (&design, &correction))| {
            let penalized = s_psi.dot(factor);
            let penalty: f64 = factor
                .iter()
                .zip(penalized.iter())
                .map(|(&f, &s)| f * s)
                .sum();
            2.0 * design + correction + penalty
        })
        .collect()
}
/// Adds the optional correction `Xᵀ (c ∘ x_col)` to `out_col`.
///
/// Does nothing when `c_x_psi_beta` is absent. Otherwise `n_work` is overwritten with the
/// elementwise product `c ∘ x_col`, `p_work` with its transposed-design product, and
/// `p_work` is then added to `out_col`.
///
/// # Panics
/// Panics if `x_col` or `n_work` does not match the length of `c`, or if `p_work` does not
/// have length `self.p`.
fn accumulate_c_correction_xt_into(
    &self,
    x_col: ArrayView1<'_, f64>,
    mut n_work: ArrayViewMut1<'_, f64>,
    mut p_work: ArrayViewMut1<'_, f64>,
    mut out_col: ArrayViewMut1<'_, f64>,
) {
    if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
        let c = c_x_psi_beta.as_ref();
        let n_obs = c.len();
        assert_eq!(x_col.len(), n_obs);
        assert_eq!(n_work.len(), n_obs);
        assert_eq!(p_work.len(), self.p);

        for row in 0..n_obs {
            n_work[row] = c[row] * x_col[row];
        }
        design_matrix_transpose_apply_view_into(
            &self.x_design,
            n_work.view(),
            p_work.view_mut(),
        );
        out_col += &p_work;
    }
}
/// Returns `Σ_i x_v[i] * c[i] * x_u[i]` using the optional `c_x_psi_beta` vector, or `0.0`
/// when that vector is absent. The sum runs over the shortest of the three sequences.
fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
    match self.c_x_psi_beta.as_ref() {
        None => 0.0,
        Some(c_x_psi_beta) => {
            let paired = x_v.iter().zip(x_u.iter());
            paired
                .zip(c_x_psi_beta.iter())
                .map(|((&left, &right), &weight)| left * weight * right)
                .sum()
        }
    }
}
/// Bilinear form on `(z, u)` that reuses caller-supplied design products.
///
/// Accumulates, in order:
/// 1. `Σ_i w_i [(D z)_i * y_vec[i] + (D u)_i * x_vec[i]]` with `D ·` from
///    `implicit_deriv.forward_mul(axis, ·)`;
/// 2. `Σ_i y_vec[i] * c[i] * x_vec[i]` when `c_x_psi_beta` is present;
/// 3. the penalty `dense_bilinear(s_psi, z, u)`.
///
/// Both row loops run over `x_vec.len()`; no length checks are performed here beyond the
/// bounds checks of indexing.
// NOTE(review): by analogy with `bilinear_view`, `x_vec` is presumably `X z` and `y_vec`
// `X u` — confirm against the callers.
///
/// # Panics
/// Panics if `forward_mul` reports an error or an index is out of bounds.
pub fn bilinear_with_shared_x(
    &self,
    x_vec: &Array1<f64>,
    y_vec: &Array1<f64>,
    z: &Array1<f64>,
    u: &Array1<f64>,
) -> f64 {
    let dx_z = self
        .implicit_deriv
        .forward_mul(self.axis, &z.view())
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");
    let dx_u = self
        .implicit_deriv
        .forward_mul(self.axis, &u.view())
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");

    let w = &*self.w_diag;
    let n_rows = x_vec.len();
    let cross_terms = (0..n_rows).fold(0.0f64, |acc, i| {
        let wi = w[i];
        acc + dx_z[i] * wi * y_vec[i] + dx_u[i] * wi * x_vec[i]
    });
    let design = match self.c_x_psi_beta.as_ref() {
        Some(c_x_psi_beta) => {
            let c = c_x_psi_beta.as_ref();
            (0..n_rows).fold(cross_terms, |acc, i| acc + y_vec[i] * c[i] * x_vec[i])
        }
        None => cross_terms,
    };
    design + dense_bilinear(&self.s_psi, z.view(), u.view())
}
/// Writes the operator–vector product for `z` into `out`, reusing a caller-supplied `x_vec`.
///
/// `out` is built up as the sum of:
/// 1. `Dᵀ (w ∘ x_vec)` via `implicit_deriv.transpose_mul`;
/// 2. `Xᵀ (w ∘ D z)` via `forward_mul` followed by the transposed design product;
/// 3. `s_psi · z`;
/// 4. `Xᵀ (c ∘ x_vec)` when `c_x_psi_beta` is present.
///
/// `n_work` (length `n_obs`) and `p_work` (length `p`) are scratch buffers and are
/// overwritten.
// NOTE(review): `x_vec` is presumably `X z` computed by the caller — confirm at call sites.
///
/// # Panics
/// Panics on buffer length mismatches or if a derivative evaluation fails.
pub fn matvec_with_shared_xz_into(
    &self,
    x_vec: ArrayView1<'_, f64>,
    z: ArrayView1<'_, f64>,
    mut out: ArrayViewMut1<'_, f64>,
    mut n_work: ArrayViewMut1<'_, f64>,
    mut p_work: ArrayViewMut1<'_, f64>,
) {
    assert_eq!(z.len(), self.p);
    assert_eq!(out.len(), self.p);
    assert_eq!(n_work.len(), self.w_diag.len());
    assert_eq!(p_work.len(), self.p);

    let w = &*self.w_diag;
    let n_obs = w.len();

    // Term 1: Dᵀ (w ∘ x_vec).
    for row in 0..n_obs {
        n_work[row] = w[row] * x_vec[row];
    }
    let derivative_transpose = self
        .implicit_deriv
        .transpose_mul(self.axis, &n_work.view())
        .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
    out.assign(&derivative_transpose);

    // Term 2: Xᵀ (w ∘ D z).
    let dx_z = self
        .implicit_deriv
        .forward_mul(self.axis, &z)
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");
    for row in 0..n_obs {
        n_work[row] = w[row] * dx_z[row];
    }
    design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
    out += &p_work;

    // Term 3: penalty derivative applied to z.
    dense_matvec_into(&self.s_psi, z, p_work.view_mut());
    out += &p_work;

    // Term 4 (optional): Xᵀ (c ∘ x_vec).
    if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
        let c = c_x_psi_beta.as_ref();
        for row in 0..n_obs {
            n_work[row] = c[row] * x_vec[row];
        }
        design_matrix_transpose_apply_view_into(
            &self.x_design,
            n_work.view(),
            p_work.view_mut(),
        );
        out += &p_work;
    }
}
}
/// Matrix-free operator whose `mul_vec` applies
/// `x_tauᵀ W X + Xᵀ W x_tau + s_tau (+ Xᵀ diag(c) X) (− firth_hphi_tau_partial)` to a vector,
/// where the bracketed terms are included only when the corresponding option is `Some`.
pub struct SparseDirectionalHyperOperator {
/// Design-derivative term, applied through `forward_mul_original` / `transpose_mul_original`.
pub(crate) x_tau: super::HyperDesignDerivative,
/// Design matrix `X`.
pub(crate) x_design: DesignMatrix,
/// Per-observation weights multiplied elementwise into the design products.
pub(crate) w_diag: crate::matrix::SignedWeightsArc,
/// Dense matrix added as `s_tau · v`.
pub(crate) s_tau: Array2<f64>,
/// Optional per-observation vector `c`; when present adds `Xᵀ (c ∘ X v)`.
pub(crate) c_x_tau_beta: Option<Array1<f64>>,
/// Optional dense matrix whose product with `v` is subtracted from the result.
pub(crate) firth_hphi_tau_partial: Option<Array2<f64>>,
/// Operator dimension (expected length of input vectors).
pub(crate) p: usize,
}
impl HyperOperator for SparseDirectionalHyperOperator {
    /// Dimension of the vectors this operator acts on.
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the operator to `v`.
    ///
    /// The result is `x_tauᵀ (w ∘ X v) + Xᵀ (w ∘ x_tau v) + s_tau v`, plus `Xᵀ (c ∘ X v)` when
    /// `c_x_tau_beta` is set, minus `firth_hphi_tau_partial · v` when that matrix is set.
    ///
    /// # Panics
    /// Panics if `v.len() != self.p` or if a design-derivative product reports an error.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let weights = &*self.w_diag;

        let x_v = self.x_design.matrixvectormultiply(v);
        let weighted_x_v = weights * &x_v;
        let derivative_transpose_part = self
            .x_tau
            .transpose_mul_original(&weighted_x_v)
            .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");

        let x_tau_v = self
            .x_tau
            .forward_mul_original(v)
            .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
        let weighted_x_tau_v = weights * &x_tau_v;
        let design_transpose_part = self.x_design.transpose_vector_multiply(&weighted_x_tau_v);

        let penalty_part = self.s_tau.dot(v);
        let mut result = derivative_transpose_part + design_transpose_part + penalty_part;

        if let Some(c_x_tau_beta) = &self.c_x_tau_beta {
            let curvature_weighted = c_x_tau_beta * &x_v;
            result += &self.x_design.transpose_vector_multiply(&curvature_weighted);
        }
        if let Some(hphi_tau_partial) = &self.firth_hphi_tau_partial {
            result -= &hphi_tau_partial.dot(v);
        }
        result
    }

    /// This operator does not identify itself as implicit.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Downcast hook returning the concrete operator.
    fn as_sparse_directional(&self) -> Option<&SparseDirectionalHyperOperator> {
        Some(self)
    }
}
/// Matrix-free operator whose `mul_vec` computes `Xᵀ (neg_c_xv ∘ X v)`.
pub struct GlmCurvatureCorrectionOperator {
/// Design matrix `X`.
pub(crate) x_design: DesignMatrix,
/// Per-observation multiplier applied elementwise to `X v`.
// NOTE(review): the name suggests this already carries a negative sign — confirm at the
// construction site.
pub(crate) neg_c_xv: Array1<f64>,
/// Operator dimension (expected length of input vectors).
pub(crate) p: usize,
}
impl HyperOperator for GlmCurvatureCorrectionOperator {
    /// Dimension of the vectors this operator acts on.
    fn dim(&self) -> usize {
        self.p
    }

    /// Returns `Xᵀ (neg_c_xv ∘ X v)`.
    ///
    /// # Panics
    /// Panics if `v.len() != self.p`.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let design_image = self.x_design.matrixvectormultiply(v);
        let scaled_image = &self.neg_c_xv * &design_image;
        self.x_design.transpose_vector_multiply(&scaled_image)
    }

    /// This operator does not identify itself as implicit.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// A scalar value bundled with a first-derivative vector and an optional second-derivative
/// matrix.
// NOTE(review): by name this holds the penalty log-determinant and its derivatives with
// respect to the smoothing parameters — confirm against the producer.
#[derive(Clone, Debug)]
pub struct PenaltyLogdetDerivs {
/// The scalar value itself.
pub value: f64,
/// First derivatives, one entry per coordinate.
pub first: Array1<f64>,
/// Second derivatives, when they were computed.
pub second: Option<Array2<f64>>,
}
/// One penalty term, stored in a form that lets its quadratic form and matrix–vector products
/// be evaluated without always building the full dense penalty matrix.
///
/// For the root-based variants the penalty matrix is `rootᵀ · root` (see
/// `scaled_dense_matrix`); the `*Centered` variants additionally carry a `prior_mean` that is
/// subtracted from the coefficients by the `shifted_*` methods.
#[derive(Clone, Debug)]
pub enum PenaltyCoordinate {
/// Root acting on the full coefficient vector (`root.ncols()` is the coefficient dimension).
DenseRoot(Array2<f64>),
/// Like `DenseRoot`, with a mean vector the coefficients are centered on.
DenseRootCentered {
root: Array2<f64>,
prior_mean: Array1<f64>,
},
/// Root acting only on the coefficient block `start..end` of a `total_dim`-long vector.
BlockRoot {
root: Array2<f64>,
start: usize,
end: usize,
total_dim: usize,
},
/// Like `BlockRoot`, with a block-length mean vector the block is centered on.
BlockRootCentered {
root: Array2<f64>,
start: usize,
end: usize,
total_dim: usize,
prior_mean: Array1<f64>,
},
/// Diagonal penalty: coefficient index `outer * q_k * stride + j * stride + inner` is scaled
/// by `eigenvalues[dim_index][j]`, where `q_k = marginal_dims[dim_index]` and `stride` is the
/// product of the marginal dimensions after `dim_index`.
KroneckerMarginal {
eigenvalues: Vec<Array1<f64>>,
dim_index: usize,
marginal_dims: Vec<usize>,
total_dim: usize,
},
}
impl PenaltyCoordinate {
/// Wraps `root` as an uncentered dense coordinate (`DenseRoot`).
pub fn from_dense_root(root: Array2<f64>) -> Self {
Self::DenseRoot(root)
}
/// Builds a dense coordinate with an associated prior mean.
///
/// When every entry of `prior_mean` equals zero the mean is dropped and a plain `DenseRoot`
/// is returned; otherwise the result is `DenseRootCentered`.
///
/// # Panics
/// Panics if `prior_mean.len()` differs from `root.ncols()`.
pub fn from_dense_root_with_mean(root: Array2<f64>, prior_mean: Array1<f64>) -> Self {
    assert_eq!(root.ncols(), prior_mean.len());
    let has_shift = prior_mean.iter().any(|&value| value != 0.0);
    if has_shift {
        Self::DenseRootCentered { root, prior_mean }
    } else {
        Self::DenseRoot(root)
    }
}
/// Builds a block-local coordinate whose root acts on coefficients `start..end` of a vector of
/// length `total_dim`.
///
/// # Panics
/// Panics if `root.ncols()` differs from the block width `end - start` (computed with
/// saturating subtraction), or if `end > total_dim`.
pub fn from_block_root(root: Array2<f64>, start: usize, end: usize, total_dim: usize) -> Self {
    let block_width = end.saturating_sub(start);
    assert_eq!(
        root.ncols(),
        block_width,
        "block prior root column count must match block width"
    );
    assert!(
        end <= total_dim,
        "block prior root end exceeds total dimension: start={start}, end={end}, total_dim={total_dim}, root_dim={:?}",
        root.dim()
    );
    Self::BlockRoot {
        root,
        start,
        end,
        total_dim,
    }
}
/// Builds a block-local coordinate with a prior mean for the block.
///
/// When every entry of `prior_mean` equals zero this falls back to `from_block_root`;
/// otherwise the result is `BlockRootCentered`.
///
/// # Panics
/// Panics if `root.ncols()` or `prior_mean.len()` differs from the block width `end - start`
/// (computed with saturating subtraction), or if `end > total_dim`.
pub fn from_block_root_with_mean(
    root: Array2<f64>,
    start: usize,
    end: usize,
    total_dim: usize,
    prior_mean: Array1<f64>,
) -> Self {
    let block_width = end.saturating_sub(start);
    assert_eq!(
        root.ncols(),
        block_width,
        "centered block prior root column count must match block width"
    );
    assert_eq!(
        prior_mean.len(),
        block_width,
        "centered block prior mean length must match block width"
    );
    assert!(
        end <= total_dim,
        "centered block prior root end exceeds total dimension: start={start}, end={end}, total_dim={total_dim}, root_dim={:?}, prior_mean_len={}",
        root.dim(),
        prior_mean.len()
    );

    let has_shift = prior_mean.iter().any(|&value| value != 0.0);
    if !has_shift {
        return Self::from_block_root(root, start, end, total_dim);
    }
    Self::BlockRootCentered {
        root,
        start,
        end,
        total_dim,
        prior_mean,
    }
}
/// Number of penalized directions.
///
/// For root-based variants this is the number of rows of the root. For `KroneckerMarginal` it
/// is the count of eigenvalues on the own axis with magnitude above `1e-12`, multiplied by the
/// product of the eigenvalue-vector lengths of all other axes (at least 1).
pub fn rank(&self) -> usize {
    match self {
        Self::KroneckerMarginal {
            eigenvalues,
            dim_index,
            ..
        } => {
            let own_axis = *dim_index;
            let nonzero = eigenvalues[own_axis]
                .iter()
                .filter(|value| value.abs() > 1e-12)
                .count();
            let other_axes: usize = eigenvalues
                .iter()
                .enumerate()
                .filter_map(|(j, e)| (j != own_axis).then(|| e.len()))
                .product::<usize>()
                .max(1);
            nonzero * other_axes
        }
        Self::DenseRoot(root)
        | Self::DenseRootCentered { root, .. }
        | Self::BlockRoot { root, .. }
        | Self::BlockRootCentered { root, .. } => root.nrows(),
    }
}
/// Length of the coefficient vector this coordinate acts on: the root's column count for the
/// dense variants, the stored `total_dim` otherwise.
pub fn dim(&self) -> usize {
    match self {
        Self::DenseRoot(root) => root.ncols(),
        Self::DenseRootCentered { root, .. } => root.ncols(),
        Self::BlockRoot { total_dim, .. } => *total_dim,
        Self::BlockRootCentered { total_dim, .. } => *total_dim,
        Self::KroneckerMarginal { total_dim, .. } => *total_dim,
    }
}
/// `true` for the block-root and Kronecker variants, `false` for the dense-root variants.
pub fn uses_operator_fast_path(&self) -> bool {
    match self {
        Self::DenseRoot(_) | Self::DenseRootCentered { .. } => false,
        Self::BlockRoot { .. }
        | Self::BlockRootCentered { .. }
        | Self::KroneckerMarginal { .. } => true,
    }
}
/// Re-expresses this coordinate in the subspace spanned by the columns of `z`.
///
/// The root becomes `root · z` (using only rows `start..end` of `z` for block variants) and a
/// prior mean `m` becomes `zᵀ m` (again restricted to the block rows). The result is always a
/// dense variant; a projected mean that is entirely zero collapses to `DenseRoot`.
///
/// # Panics
/// Panics if `z.nrows()` differs from `self.dim()`. Raises a contract panic for
/// `KroneckerMarginal`, which is not supported here.
pub fn project_into_subspace(&self, z: &Array2<f64>) -> Self {
    assert_eq!(
        z.nrows(),
        self.dim(),
        "PenaltyCoordinate::project_into_subspace: free-basis row count {} does not match coordinate dimension {}",
        z.nrows(),
        self.dim()
    );
    match self {
        Self::DenseRoot(root) => Self::DenseRoot(root.dot(z)),
        Self::DenseRootCentered { root, prior_mean } => {
            let reduced_root = root.dot(z);
            let reduced_mean = z.t().dot(prior_mean);
            Self::from_dense_root_with_mean(reduced_root, reduced_mean)
        }
        Self::BlockRoot {
            root, start, end, ..
        } => {
            let block_rows = z.slice(ndarray::s![*start..*end, ..]);
            Self::DenseRoot(root.dot(&block_rows))
        }
        Self::BlockRootCentered {
            root,
            start,
            end,
            prior_mean,
            ..
        } => {
            let block_rows = z.slice(ndarray::s![*start..*end, ..]).to_owned();
            let reduced_root = root.dot(&block_rows);
            let reduced_mean = block_rows.t().dot(prior_mean);
            Self::from_dense_root_with_mean(reduced_root, reduced_mean)
        }
        Self::KroneckerMarginal { .. } => reml_contract_panic(
            "PenaltyCoordinate::project_into_subspace: Kronecker-factored \
             coordinates do not co-occur with linear-inequality active sets \
             (box/monotone constraints lower to dense/block roots)",
        ),
    }
}
/// Returns `root · beta` (dense variants) or `root · beta[start..end]` (block variants).
///
/// # Panics
/// Panics if `beta.len() != self.dim()`. Raises a contract panic for `KroneckerMarginal`,
/// which has no root representation.
fn apply_root(&self, beta: &Array1<f64>) -> Array1<f64> {
    assert_eq!(beta.len(), self.dim());
    match self {
        Self::DenseRoot(root) => root.dot(beta),
        Self::DenseRootCentered { root, .. } => root.dot(beta),
        Self::BlockRoot {
            root, start, end, ..
        }
        | Self::BlockRootCentered {
            root, start, end, ..
        } => {
            let block = beta.slice(ndarray::s![*start..*end]);
            root.dot(&block)
        }
        Self::KroneckerMarginal { .. } => reml_contract_panic(
            "apply_root not supported for KroneckerMarginal; use apply_penalty directly",
        ),
    }
}
/// Returns `scale · S · beta` as a freshly allocated vector, where `S` is this coordinate's
/// penalty matrix.
///
/// # Panics
/// Panics if `beta.len() != self.dim()`.
pub fn apply_penalty(&self, beta: &Array1<f64>, scale: f64) -> Array1<f64> {
    assert_eq!(beta.len(), self.dim());
    let mut result = Array1::<f64>::zeros(self.dim());
    self.apply_penalty_view_into(beta.view(), scale, result.view_mut());
    result
}
/// Overwrites `out` with `scale · S · beta`: zeroes `out`, then accumulates through
/// `scaled_add_penalty_view`.
///
/// # Panics
/// Panics if `beta` or `out` does not have length `self.dim()`.
pub fn apply_penalty_view_into(
    &self,
    beta: ArrayView1<'_, f64>,
    scale: f64,
    mut out: ArrayViewMut1<'_, f64>,
) {
    let dim = self.dim();
    assert_eq!(beta.len(), dim);
    assert_eq!(out.len(), dim);
    // Start from zero so the accumulate-style helper yields a plain product.
    out.fill(0.0);
    self.scaled_add_penalty_view(beta, scale, out);
}
/// Accumulates `out += scale · S · beta`, where `S` is this coordinate's penalty matrix.
///
/// * Dense roots: `out += scale · rootᵀ (root · beta)`.
/// * Block roots: the same, restricted to `beta[start..end]` and `out[start..end]`.
/// * `KroneckerMarginal`: `S` is diagonal; entry `outer * q_k * stride + j * stride + inner`
///   is scaled by `eigenvalues[dim_index][j]`.
///
/// A `scale` of exactly `0.0` returns immediately and leaves `out` untouched.
///
/// # Panics
/// Panics if `beta` or `out` does not have length `self.dim()`, or if the Kronecker marginal
/// dimensions do not multiply out to `total_dim`.
pub fn scaled_add_penalty_view(
    &self,
    beta: ArrayView1<'_, f64>,
    scale: f64,
    mut out: ArrayViewMut1<'_, f64>,
) {
    assert_eq!(beta.len(), self.dim());
    assert_eq!(out.len(), self.dim());
    if scale == 0.0 {
        return;
    }
    // Single exhaustive match: the previous nested `match self` carried an arm for
    // `KroneckerMarginal` that could never be reached.
    match self {
        Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
            let mut root_beta = Array1::<f64>::zeros(root.nrows());
            dense_matvec_into(root, beta, root_beta.view_mut());
            dense_transpose_matvec_scaled_add_into(
                root,
                root_beta.view(),
                scale,
                out.view_mut(),
            );
        }
        Self::BlockRoot {
            root, start, end, ..
        }
        | Self::BlockRootCentered {
            root, start, end, ..
        } => {
            let beta_block = beta.slice(ndarray::s![*start..*end]);
            let mut root_beta = Array1::<f64>::zeros(root.nrows());
            dense_matvec_into(root, beta_block, root_beta.view_mut());
            let out_block = out.slice_mut(ndarray::s![*start..*end]);
            dense_transpose_matvec_scaled_add_into(root, root_beta.view(), scale, out_block);
        }
        Self::KroneckerMarginal {
            eigenvalues,
            dim_index,
            marginal_dims,
            total_dim,
        } => {
            let k = *dim_index;
            let q_k = marginal_dims[k];
            // Number of consecutive coefficients sharing one eigenvalue of axis `k`.
            let stride_k: usize = marginal_dims[k + 1..]
                .iter()
                .copied()
                .product::<usize>()
                .max(1);
            // Number of repetitions of the whole axis-`k` pattern.
            let outer_size: usize =
                marginal_dims[..k].iter().copied().product::<usize>().max(1);
            let eigs = &eigenvalues[k];
            assert_eq!(
                outer_size * q_k * stride_k,
                *total_dim,
                "KroneckerMarginal dimension mismatch in apply"
            );
            for outer in 0..outer_size {
                for j in 0..q_k {
                    let mu = eigs[j] * scale;
                    if mu == 0.0 {
                        continue;
                    }
                    let base = outer * q_k * stride_k + j * stride_k;
                    for inner in 0..stride_k {
                        let idx = base + inner;
                        out[idx] += mu * beta[idx];
                    }
                }
            }
        }
    }
}
/// Returns `scale · betaᵀ S beta` for this coordinate's penalty matrix `S` (no centering).
///
/// Root-based variants evaluate `scale · ‖root · beta‖²` through `apply_root`; the Kronecker
/// variant sums `eigenvalue · beta[idx]²` over the strided layout and multiplies by `scale`
/// at the end.
///
/// # Panics
/// Root-based variants panic if `beta.len() != self.dim()`; the Kronecker variant panics on
/// out-of-bounds indexing.
pub fn quadratic(&self, beta: &Array1<f64>, scale: f64) -> f64 {
    match self {
        Self::KroneckerMarginal {
            eigenvalues,
            dim_index,
            marginal_dims,
            ..
        } => {
            let axis = *dim_index;
            let axis_len = marginal_dims[axis];
            let stride: usize = marginal_dims[axis + 1..]
                .iter()
                .copied()
                .product::<usize>()
                .max(1);
            let outer_count: usize = marginal_dims[..axis]
                .iter()
                .copied()
                .product::<usize>()
                .max(1);
            let eigs = &eigenvalues[axis];

            let mut total = 0.0_f64;
            for outer in 0..outer_count {
                for j in 0..axis_len {
                    let mu = eigs[j];
                    if mu == 0.0 {
                        // Unpenalized direction: contributes nothing.
                        continue;
                    }
                    let base = outer * axis_len * stride + j * stride;
                    for offset in 0..stride {
                        let value = beta[base + offset];
                        total += mu * value * value;
                    }
                }
            }
            total * scale
        }
        Self::DenseRoot(_)
        | Self::DenseRootCentered { .. }
        | Self::BlockRoot { .. }
        | Self::BlockRootCentered { .. } => {
            let image = self.apply_root(beta);
            scale * image.dot(&image)
        }
    }
}
/// Returns `scale · S · (beta − prior_mean)` for the centered variants, and
/// `scale · S · beta` (via `apply_penalty`) for every other variant.
///
/// For `BlockRootCentered` the result has length `total_dim` and is zero outside
/// `start..end`.
pub fn apply_shifted_penalty(&self, beta: &Array1<f64>, scale: f64) -> Array1<f64> {
    match self {
        Self::DenseRootCentered { root, prior_mean } => {
            let centered = beta - prior_mean;
            // Restores the `&centered` borrow that had been corrupted into a stray `¢` token.
            let root_beta = root.dot(&centered);
            let mut out = root.t().dot(&root_beta);
            out *= scale;
            out
        }
        Self::BlockRootCentered {
            root,
            start,
            end,
            total_dim,
            prior_mean,
        } => {
            let mut out = Array1::<f64>::zeros(*total_dim);
            let beta_block = beta.slice(ndarray::s![*start..*end]);
            let centered = beta_block.to_owned() - prior_mean;
            let root_beta = root.dot(&centered);
            let mut block = root.t().dot(&root_beta);
            block *= scale;
            out.slice_mut(ndarray::s![*start..*end]).assign(&block);
            out
        }
        // Uncentered variants have no shift to apply.
        Self::DenseRoot(_) | Self::BlockRoot { .. } | Self::KroneckerMarginal { .. } => {
            self.apply_penalty(beta, scale)
        }
    }
}
/// Returns `scale · ‖root · (beta − prior_mean)‖²` for the centered variants (restricted to
/// `beta[start..end]` for the block form), and falls back to `quadratic` for every other
/// variant.
pub fn shifted_quadratic(&self, beta: &Array1<f64>, scale: f64) -> f64 {
    match self {
        Self::DenseRootCentered { root, prior_mean } => {
            let centered = beta - prior_mean;
            // Restores the `&centered` borrow that had been corrupted into a stray `¢` token.
            let root_beta = root.dot(&centered);
            scale * root_beta.dot(&root_beta)
        }
        Self::BlockRootCentered {
            root,
            start,
            end,
            prior_mean,
            ..
        } => {
            let beta_block = beta.slice(ndarray::s![*start..*end]);
            let centered = beta_block.to_owned() - prior_mean;
            let root_beta = root.dot(&centered);
            scale * root_beta.dot(&root_beta)
        }
        // Uncentered variants have no shift to apply.
        Self::DenseRoot(_) | Self::BlockRoot { .. } | Self::KroneckerMarginal { .. } => {
            self.quadratic(beta, scale)
        }
    }
}
/// Materializes `scale · S` as a dense matrix.
///
/// * Dense roots: `scale · rootᵀ root`.
/// * Block roots: a `total_dim × total_dim` zero matrix with `scale · rootᵀ root` placed on
///   the `start..end` diagonal block.
/// * `KroneckerMarginal`: a diagonal matrix with `scale · eigenvalue` along the strided
///   layout.
///
/// # Panics
/// Panics if the Kronecker marginal dimensions do not multiply out to `total_dim`.
pub fn scaled_dense_matrix(&self, scale: f64) -> Array2<f64> {
    match self {
        Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
            let mut gram = root.t().dot(root);
            gram *= scale;
            gram
        }
        Self::BlockRoot {
            root,
            start,
            end,
            total_dim,
        }
        | Self::BlockRootCentered {
            root,
            start,
            end,
            total_dim,
            ..
        } => {
            let mut full = Array2::<f64>::zeros((*total_dim, *total_dim));
            let mut gram = root.t().dot(root);
            gram *= scale;
            full.slice_mut(ndarray::s![*start..*end, *start..*end])
                .assign(&gram);
            full
        }
        Self::KroneckerMarginal {
            eigenvalues,
            dim_index,
            marginal_dims,
            total_dim,
        } => {
            let axis = *dim_index;
            let axis_len = marginal_dims[axis];
            let stride: usize = marginal_dims[axis + 1..]
                .iter()
                .copied()
                .product::<usize>()
                .max(1);
            let outer_count: usize = marginal_dims[..axis]
                .iter()
                .copied()
                .product::<usize>()
                .max(1);
            let eigs = &eigenvalues[axis];
            assert_eq!(
                outer_count * axis_len * stride,
                *total_dim,
                "KroneckerMarginal dimension mismatch in to_dense"
            );

            let mut full = Array2::<f64>::zeros((*total_dim, *total_dim));
            for outer in 0..outer_count {
                for j in 0..axis_len {
                    let diagonal_value = eigs[j] * scale;
                    let base = outer * axis_len * stride + j * stride;
                    for idx in base..base + stride {
                        full[[idx, idx]] = diagonal_value;
                    }
                }
            }
            full
        }
    }
}
/// Returns `(matrix, start, end)` where `matrix` is `scale · S` restricted to the coefficient
/// range `start..end`.
///
/// Dense roots report the whole range `0..p`; block roots report only their own block; the
/// Kronecker variant reports the full dense matrix over `0..total_dim`.
pub fn scaled_block_local(&self, scale: f64) -> (Array2<f64>, usize, usize) {
    match self {
        Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
            let mut gram = root.t().dot(root);
            gram *= scale;
            let width = gram.nrows();
            (gram, 0, width)
        }
        Self::BlockRoot {
            root, start, end, ..
        }
        | Self::BlockRootCentered {
            root, start, end, ..
        } => {
            let mut gram = root.t().dot(root);
            gram *= scale;
            (gram, *start, *end)
        }
        Self::KroneckerMarginal { total_dim, .. } => {
            (self.scaled_dense_matrix(scale), 0, *total_dim)
        }
    }
}
/// `true` for the block-root and Kronecker variants, `false` for the dense-root variants.
pub fn is_block_local(&self) -> bool {
    match self {
        Self::BlockRoot { .. }
        | Self::BlockRootCentered { .. }
        | Self::KroneckerMarginal { .. } => true,
        Self::DenseRoot(_) | Self::DenseRootCentered { .. } => false,
    }
}
/// Returns `scale · S · v` as a new vector.
///
/// Dense roots compute `scale · rootᵀ (root · v)`; block roots do the same on `v[start..end]`
/// and leave the remaining entries of the `v.len()`-long result at zero; the Kronecker variant
/// delegates to `apply_penalty`.
pub fn scaled_matvec(&self, v: &Array1<f64>, scale: f64) -> Array1<f64> {
    match self {
        Self::DenseRoot(root) | Self::DenseRootCentered { root, .. } => {
            let image = root.dot(v);
            let mut result = root.t().dot(&image);
            result *= scale;
            result
        }
        Self::BlockRoot {
            root, start, end, ..
        }
        | Self::BlockRootCentered {
            root, start, end, ..
        } => {
            let mut result = Array1::zeros(v.len());
            let image = root.dot(&v.slice(ndarray::s![*start..*end]));
            let mut block = root.t().dot(&image);
            block *= scale;
            result.slice_mut(ndarray::s![*start..*end]).assign(&block);
            result
        }
        Self::KroneckerMarginal { .. } => self.apply_penalty(v, scale),
    }
}
}
/// A basis `u_s` together with an `r × r` matrix `h_proj_inverse` (`r = u_s.ncols()`), used to
/// evaluate traces, bilinear forms and pseudo-inverse applications in the reduced coordinates
/// `u_sᵀ · (·) · u_s`.
// NOTE(review): the helper name `trace_penalty_covariance_in_orthogonal_basis` suggests `u_s`
// has orthonormal columns and `h_proj_inverse` is the inverse of the projected Hessian —
// confirm at the construction site.
#[derive(Clone, Debug)]
pub struct PenaltySubspaceTrace {
/// Basis matrix with `p` rows and `r` columns.
pub u_s: Array2<f64>,
/// `r × r` matrix applied in the reduced coordinates.
pub h_proj_inverse: Array2<f64>,
}
impl PenaltySubspaceTrace {
/// Delegates to `trace_penalty_covariance_in_orthogonal_basis(a, u_s, h_proj_inverse)` for a
/// full-size matrix `a`.
// NOTE(review): presumably equals `tr(h_proj_inverse · u_sᵀ a u_s)`, matching
// `trace_projected_logdet_reduced(reduce(a))` — confirm in `crate::construction`.
pub fn trace_projected_logdet(&self, a: &Array2<f64>) -> f64 {
crate::construction::trace_penalty_covariance_in_orthogonal_basis(
a,
&self.u_s,
&self.h_proj_inverse,
)
}
/// Projects a full-size matrix into the reduced coordinates: returns `u_sᵀ · a · u_s`.
pub fn reduce(&self, a: &Array2<f64>) -> Array2<f64> {
    let left_projected = crate::faer_ndarray::fast_atb(&self.u_s, a);
    let reduced = crate::faer_ndarray::fast_ab(&left_projected, &self.u_s);
    reduced
}
/// Delegates to `trace_reduced_penalty_covariance(r_mat, h_proj_inverse)` for a matrix that
/// is already expressed in the reduced coordinates (see `reduce`).
pub fn trace_projected_logdet_reduced(&self, r_mat: &Array2<f64>) -> f64 {
crate::construction::trace_reduced_penalty_covariance(r_mat, &self.h_proj_inverse)
}
/// Passes `h_proj_inverse · ra` and `h_proj_inverse · rb` to `trace_matrix_product`, for two
/// matrices already expressed in the reduced coordinates.
pub fn trace_projected_logdet_cross_reduced(&self, ra: &Array2<f64>, rb: &Array2<f64>) -> f64 {
    let h_inv = &self.h_proj_inverse;
    let (left, right) = (h_inv.dot(ra), h_inv.dot(rb));
    trace_matrix_product(&left, &right)
}
/// Reduces a matrix-free operator: returns `u_sᵀ · (a · u_s)`, where `a · u_s` is obtained
/// from the operator's `mul_mat`.
pub fn reduce_operator<O>(&self, a: &O) -> Array2<f64>
where
    O: HyperOperator + ?Sized,
{
    let basis = &self.u_s;
    let operator_image = a.mul_mat(basis);
    crate::faer_ndarray::fast_atb(basis, &operator_image)
}
/// Reduces the operator with `reduce_operator` and evaluates
/// `trace_projected_logdet_reduced` on the result.
pub fn trace_operator<O>(&self, a: &O) -> f64
where
    O: HyperOperator + ?Sized,
{
    let reduced = self.reduce_operator(a);
    self.trace_projected_logdet_reduced(&reduced)
}
/// Returns the per-row quadratic form `h[i] = z_iᵀ · h_proj_inverse · z_i`, where `z_i` is
/// row `i` of `x · u_s`.
///
/// Rows of `x` are processed in chunks of roughly `2^16 / p` rows (at least one) so the dense
/// row chunk stays small.
///
/// # Panics
/// Panics if `u_s` does not have `x.ncols()` rows or `h_proj_inverse` is not `r × r`; raises
/// a contract panic if a row chunk cannot be produced.
pub fn xt_projected_kernel_x_diagonal(&self, x: &DesignMatrix) -> Array1<f64> {
    const TARGET_CHUNK_FLOATS: usize = 1 << 16;

    let (n, p) = (x.nrows(), x.ncols());
    let r = self.u_s.ncols();
    assert_eq!(self.u_s.nrows(), p);
    assert_eq!(self.h_proj_inverse.nrows(), r);
    assert_eq!(self.h_proj_inverse.ncols(), r);

    // Clamped to at least one row, so the step below is never zero.
    let block = (TARGET_CHUNK_FLOATS / p.max(1)).clamp(1, n.max(1));
    let mut h = Array1::<f64>::zeros(n);
    for start in (0..n).step_by(block) {
        let end = (start + block).min(n);
        let rows = match x.try_row_chunk(start..end) {
            Ok(rows) => rows,
            Err(err) => reml_contract_panic(format!(
                "xt_projected_kernel_x_diagonal: row chunk failed: {err}"
            )),
        };
        let z_chunk = crate::faer_ndarray::fast_ab(&rows, &self.u_s);
        for (i, row_z) in z_chunk.outer_iter().enumerate() {
            let quadratic_form = row_z
                .iter()
                .copied()
                .zip(self.h_proj_inverse.rows().into_iter())
                .fold(0.0_f64, |acc, (z_a, h_row)| {
                    let inner = h_row
                        .iter()
                        .copied()
                        .zip(row_z.iter().copied())
                        .fold(0.0_f64, |partial, (h_value, z_b)| partial + h_value * z_b);
                    acc + z_a * inner
                });
            h[start + i] = quadratic_form;
        }
    }
    h
}
/// Returns `(u_sᵀ a)ᵀ · h_proj_inverse · (u_sᵀ b)`.
pub fn bilinear_pseudo_inverse(&self, a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let reduced_a = crate::faer_ndarray::fast_atv(&self.u_s, a);
    let reduced_b = crate::faer_ndarray::fast_atv(&self.u_s, b);
    reduced_a.dot(&self.h_proj_inverse.dot(&reduced_b))
}
/// Returns `u_s · (u_sᵀ a)`.
pub fn project_onto_subspace(&self, a: &Array1<f64>) -> Array1<f64> {
    let coefficients = crate::faer_ndarray::fast_atv(&self.u_s, a);
    let lifted = crate::faer_ndarray::fast_av(&self.u_s, &coefficients);
    lifted
}
/// Applies the `FitSensitivity` built by `sensitivity()` to `a`.
// NOTE(review): presumably `u_s · h_proj_inverse · u_sᵀ · a` — confirm in
// `FitSensitivity::apply`.
pub fn apply_pseudo_inverse(&self, a: &Array1<f64>) -> Array1<f64> {
self.sensitivity().apply(a)
}
/// Builds a borrowed `FitSensitivity` from `u_s` and `h_proj_inverse` via `from_projected`.
pub fn sensitivity(&self) -> crate::solver::sensitivity::FitSensitivity<'_> {
crate::solver::sensitivity::FitSensitivity::from_projected(&self.u_s, &self.h_proj_inverse)
}
/// Wraps this kernel with a correction that keeps responses inside the null space of the
/// active constraint rows `a_act`.
///
/// With `K` denoting `apply_pseudo_inverse`, this precomputes
/// * `Z = K · a_actᵀ` (one column per active row, shape `(p, k_active)`),
/// * `M = a_act · Z`, symmetrized by averaging with its transpose,
/// * `M⁺`, a pseudo-inverse of `M` assembled from the eigenpairs whose eigenvalue exceeds
///   `ε · k_active · max(σ_max, 1)`; smaller directions are dropped and reported at debug
///   level.
///
/// With no active rows the returned kernel carries empty `Z` (shape `(p, 0)`) and `M⁺`.
///
/// If the eigendecomposition itself fails, a warning is logged and every direction is treated
/// as dropped, so `M⁺` is zero and the wrapped kernel applies no correction.
pub fn with_active_constraints<'a>(
    &'a self,
    a_act: ndarray::ArrayView2<'a, f64>,
) -> ConstrainedSubspaceKernel<'a> {
    let k_active = a_act.nrows();
    let p = self.u_s.nrows();
    if k_active == 0 {
        return ConstrainedSubspaceKernel {
            kernel: self,
            // Same orientation as the populated case: one column per active row.
            z: Array2::zeros((p, 0)),
            a_act,
            m_inv: Array2::zeros((0, 0)),
            k_active: 0,
        };
    }

    // Z = K · a_actᵀ, built one constraint row at a time.
    let mut z = Array2::<f64>::zeros((p, k_active));
    for j in 0..k_active {
        let a_row = a_act.row(j).to_owned();
        let k_s_a_row = self.apply_pseudo_inverse(&a_row);
        z.column_mut(j).assign(&k_s_a_row);
    }

    // M = a_act · Z, symmetrized to remove round-off asymmetry before the eigensolve.
    let mut m = a_act.dot(&z);
    for i in 0..k_active {
        for j in 0..i {
            let avg = 0.5 * (m[[i, j]] + m[[j, i]]);
            m[[i, j]] = avg;
            m[[j, i]] = avg;
        }
    }

    let (evals, evecs) = m.eigh(faer::Side::Lower).unwrap_or_else(|_| {
        // Best-effort fallback, but no longer silent: a zero spectrum drops every direction.
        log::warn!(
            "[constrained-subspace kernel] eigendecomposition of the active-constraint Gram \
             matrix failed; treating all {} active-constraint directions as dropped",
            k_active,
        );
        (
            Array1::<f64>::zeros(k_active),
            Array2::<f64>::eye(k_active),
        )
    });
    let sigma_max = evals.iter().copied().fold(0.0_f64, f64::max).max(0.0);
    let tol = f64::EPSILON * (k_active as f64) * sigma_max.max(1.0);

    // Spectral pseudo-inverse: keep only eigen-directions above the threshold.
    let mut m_inv = Array2::<f64>::zeros((k_active, k_active));
    let mut dropped = 0usize;
    for q in 0..k_active {
        if evals[q] > tol {
            let inv_sigma = 1.0 / evals[q];
            for i in 0..k_active {
                for j in 0..k_active {
                    m_inv[[i, j]] += inv_sigma * evecs[[i, q]] * evecs[[j, q]];
                }
            }
        } else {
            dropped += 1;
        }
    }
    if dropped > 0 {
        log::debug!(
            "[constrained-subspace kernel] dropped {} of {} active-constraint directions \
             (rank-deficient on range(S₊)); pseudo-inverse threshold = {:.3e}",
            dropped,
            k_active,
            tol,
        );
    }
    ConstrainedSubspaceKernel {
        kernel: self,
        z,
        a_act,
        m_inv,
        k_active,
    }
}
}
/// A `PenaltySubspaceTrace` kernel together with the precomputed pieces needed to correct its
/// pseudo-inverse application for a set of active constraint rows. Built by
/// `PenaltySubspaceTrace::with_active_constraints`.
pub struct ConstrainedSubspaceKernel<'a> {
/// The underlying unconstrained kernel.
kernel: &'a PenaltySubspaceTrace,
/// Kernel applied to each active constraint row, stored as columns.
z: Array2<f64>,
/// Active constraint rows.
a_act: ndarray::ArrayView2<'a, f64>,
/// Pseudo-inverse of the symmetrized `a_act · z`.
m_inv: Array2<f64>,
/// Number of active constraint rows.
k_active: usize,
}
impl<'a> ConstrainedSubspaceKernel<'a> {
    /// Applies the constrained pseudo-inverse: `K a − Z · M⁺ · (a_act · K a)`, where `K a` is
    /// the underlying kernel's `apply_pseudo_inverse(a)`. With no active constraints the
    /// unconstrained result is returned as is.
    pub fn apply_pseudo_inverse(&self, a: &Array1<f64>) -> Array1<f64> {
        let unconstrained = self.kernel.apply_pseudo_inverse(a);
        if self.k_active == 0 {
            return unconstrained;
        }
        let constraint_violation = self.a_act.dot(&unconstrained);
        let multipliers = self.m_inv.dot(&constraint_violation);
        let correction = self.z.dot(&multipliers);
        unconstrained - &correction
    }

    /// Whether at least one constraint row is active.
    pub fn has_active_constraints(&self) -> bool {
        self.k_active > 0
    }
}
/// Relative tolerance used by `ThetaModeResponseKernel::certify_tangency`: an active row is
/// reported when `|a·v|` exceeds this factor times `(Σ_k |a_k v_k| + f64::EPSILON)`.
const THETA_MODE_RESPONSE_TANGENCY_GATE: f64 = 1e-6;
/// Chooses between two ways of solving for a response vector: the constrained subspace
/// kernel when one with active constraints is available, otherwise the Hessian operator's own
/// solve.
pub(crate) struct ThetaModeResponseKernel<'s> {
/// Fallback solver used when `constrained` is `None`.
hop: &'s dyn HessianOperator,
/// Constrained kernel; only stored when it has at least one active constraint.
constrained: Option<ConstrainedSubspaceKernel<'s>>,
}
impl<'s> ThetaModeResponseKernel<'s> {
    /// Builds the kernel selector.
    ///
    /// A constrained kernel is stored only when both a subspace kernel and an active-constraint
    /// block are supplied and the block has at least one row; in every other case responses
    /// fall back to `hop`.
    pub(crate) fn select(
        subspace: Option<&'s PenaltySubspaceTrace>,
        active_constraints: Option<&'s ActiveLinearConstraintBlock>,
        hop: &'s dyn HessianOperator,
    ) -> Self {
        let constrained = if let (Some(kernel), Some(block)) = (subspace, active_constraints) {
            let candidate = kernel.with_active_constraints(block.a.view());
            if candidate.has_active_constraints() {
                Some(candidate)
            } else {
                None
            }
        } else {
            None
        };
        Self { hop, constrained }
    }

    /// Solves for a single right-hand side. On the constrained path the result is also passed
    /// through `certify_tangency`.
    pub(crate) fn respond_one(&self, rhs: &Array1<f64>) -> Array1<f64> {
        if let Some(ck) = self.constrained.as_ref() {
            let response = ck.apply_pseudo_inverse(rhs);
            self.certify_tangency(ck, &response);
            response
        } else {
            self.hop.solve(rhs)
        }
    }

    /// Solves for every column of `rhs_stack`. The constrained path handles columns one at a
    /// time (checking each); the fallback uses `hop.solve_multi`.
    pub(crate) fn respond_stack(&self, rhs_stack: &Array2<f64>) -> Array2<f64> {
        let Some(ck) = self.constrained.as_ref() else {
            return self.hop.solve_multi(rhs_stack);
        };
        let mut responses = Array2::<f64>::zeros(rhs_stack.raw_dim());
        for (j, column) in rhs_stack.columns().into_iter().enumerate() {
            let response = ck.apply_pseudo_inverse(&column.to_owned());
            self.certify_tangency(ck, &response);
            responses.column_mut(j).assign(&response);
        }
        responses
    }

    /// Diagnostic check that `v` lies in the null space of the active constraint rows.
    ///
    /// For each active row `a` the residual `a·v` is compared with
    /// `THETA_MODE_RESPONSE_TANGENCY_GATE * (Σ_k |a_k v_k| + f64::EPSILON)`; a warning is logged
    /// for every row that exceeds it. The vector itself is never modified.
    fn certify_tangency(&self, ck: &ConstrainedSubspaceKernel<'_>, v: &Array1<f64>) {
        let residual = ck.a_act.dot(v);
        for (row, r) in residual.iter().enumerate() {
            let active_row = ck.a_act.row(row);
            let scale: f64 = active_row
                .iter()
                .zip(v.iter())
                .map(|(a, x)| (a * x).abs())
                .sum();
            let gate = THETA_MODE_RESPONSE_TANGENCY_GATE * (scale + f64::EPSILON);
            if r.abs() > gate {
                log::warn!(
                    "[CERTIFICATE warning] atom \"theta_mode_response\": constrained IFT \
                     mode response left ker(A_act) — active row {row} residual {:.3e} \
                     exceeds gate {:.1e}·{:.3e}; the lifted kernel K_T and its emission \
                     have desynced (#931 pass-2 invariant)",
                    r.abs(),
                    THETA_MODE_RESPONSE_TANGENCY_GATE,
                    scale,
                );
            }
        }
    }
}
/// Tags which subspace a `ProjectedKktResidual` is expressed in.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KktResidualSubspace {
/// Residual as supplied to `ProjectedKktResidual::from_active_projected`.
ActiveProjected,
/// Residual after projection with `PenaltySubspaceTrace::project_onto_subspace`.
ReducedRange,
}
/// A KKT residual vector tagged with the subspace it lives in, plus optional metadata.
#[derive(Clone, Debug)]
pub struct ProjectedKktResidual {
/// The residual values.
residual: Array1<f64>,
/// Which subspace `residual` is expressed in.
subspace: KktResidualSubspace,
/// Optional absolute tolerance; when set it is used as the gate in
/// `projected_into_reduced_range` instead of the default relative one.
residual_tol: Option<f64>,
/// Optional rank recorded through `with_metadata`.
free_rank: Option<usize>,
}
impl ProjectedKktResidual {
    /// Wraps a residual tagged as `ActiveProjected`, with no metadata.
    pub(crate) fn from_active_projected(residual: Array1<f64>) -> Self {
        Self::tagged(residual, KktResidualSubspace::ActiveProjected)
    }

    /// Wraps a residual tagged as `ReducedRange`, with no metadata.
    fn from_reduced_range(residual: Array1<f64>) -> Self {
        Self::tagged(residual, KktResidualSubspace::ReducedRange)
    }

    /// Shared constructor: stores the residual with the given tag and empty metadata.
    fn tagged(residual: Array1<f64>, subspace: KktResidualSubspace) -> Self {
        Self {
            residual,
            subspace,
            residual_tol: None,
            free_rank: None,
        }
    }

    /// Attaches a residual tolerance and a free rank, returning the updated value.
    pub(crate) fn with_metadata(self, residual_tol: f64, free_rank: usize) -> Self {
        Self {
            residual_tol: Some(residual_tol),
            free_rank: Some(free_rank),
            ..self
        }
    }

    /// Borrows the residual vector.
    pub fn as_array(&self) -> &Array1<f64> {
        &self.residual
    }

    /// Returns the subspace tag.
    pub fn subspace(&self) -> KktResidualSubspace {
        self.subspace
    }

    /// Converts the residual into the reduced range of `kernel`.
    ///
    /// A residual already tagged `ReducedRange` is cloned unchanged. An `ActiveProjected`
    /// residual is projected with `kernel.project_onto_subspace`; the largest absolute
    /// component removed by the projection must not exceed the gate, which is `residual_tol`
    /// when set and `1e-10 * (1 + ‖residual‖∞)` otherwise. Metadata is carried over.
    ///
    /// # Errors
    /// Returns a descriptive message when the removed mass exceeds the gate.
    fn projected_into_reduced_range(&self, kernel: &PenaltySubspaceTrace) -> Result<Self, String> {
        const DEFAULT_KKT_RESIDUAL_REL_TOL: f64 = 1e-10;

        if self.subspace == KktResidualSubspace::ReducedRange {
            return Ok(self.clone());
        }

        let reduced_residual = kernel.project_onto_subspace(&self.residual);
        // Largest component lost by projecting onto the reduced range.
        let dropped_inf = self
            .residual
            .iter()
            .zip(reduced_residual.iter())
            .map(|(full, reduced)| (full - reduced).abs())
            .fold(0.0_f64, f64::max);
        let residual_inf = self
            .residual
            .iter()
            .map(|value| value.abs())
            .fold(0.0_f64, f64::max);
        let gate = match self.residual_tol {
            Some(tol) => tol,
            None => DEFAULT_KKT_RESIDUAL_REL_TOL * (1.0 + residual_inf),
        };
        if dropped_inf > gate {
            return Err(format!(
                "projected KKT residual contains unresolved mass outside the reduced \
                 Hessian/penalty range: |r_A - r_R|∞={dropped_inf:.3e} > tol={gate:.3e}; \
                 range-projected IFT correction is valid only after the null direction is \
                 explicitly removed/fixed or after the active-projected residual is small"
            ));
        }

        Ok(Self {
            residual: reduced_residual,
            subspace: KktResidualSubspace::ReducedRange,
            residual_tol: self.residual_tol,
            free_rank: self.free_rank,
        })
    }

    /// Returns the recorded residual tolerance, if any.
    pub fn residual_tol(&self) -> Option<f64> {
        self.residual_tol
    }

    /// Returns the recorded free rank, if any.
    pub fn free_rank(&self) -> Option<usize> {
        self.free_rank
    }
}
/// How the dispersion parameter is treated by the outer objective.
#[derive(Clone, Debug)]
pub enum DispersionHandling {
/// Dispersion is profiled: `efs_profiling` estimates it from the floored penalized deviance
/// divided by `n_observations - nullspace_dim`.
ProfiledGaussian,
/// Dispersion is supplied by the caller.
Fixed {
/// The fixed dispersion value; `efs_profiling` returns it unchanged.
phi: f64,
// NOTE(review): presumably toggles the log|H| term of the objective — confirm at the use site.
include_logdet_h: bool,
// NOTE(review): presumably toggles the log|S| term of the objective — confirm at the use site.
include_logdet_s: bool,
},
}
/// Everything the outer evaluation needs about a converged inner fit.
///
/// Normally produced by `InnerSolutionBuilder::build`, which validates the dimensions of the
/// pieces against `beta`. `try_tangent_projected_evaluate` also constructs one directly.
pub struct InnerSolution<'dp> {
/// Log-likelihood value; enters the profiled deviance as `-2 * log_likelihood`.
pub log_likelihood: f64,
/// Penalty quadratic value; added to `-2 * log_likelihood` in `efs_profiling`.
pub penalty_quadratic: f64,
/// Hessian operator; the builder requires `dim()` to equal `beta.len()`.
pub hessian_op: Arc<dyn HessianOperator>,
/// Coefficient vector.
pub beta: Array1<f64>,
/// Penalty coordinates; the builder requires each `dim()` to equal `beta.len()`.
pub penalty_coords: Vec<PenaltyCoordinate>,
/// Penalty log-determinant value and derivatives; `first` has one entry per penalty
/// coordinate and `second`, when present, is square with that size (checked by the builder).
pub penalty_logdet: PenaltyLogdetDerivs,
/// Provider of Hessian derivative corrections (builder default: `GaussianDerivatives`).
pub deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
// NOTE(review): "tk" correction term and its gradient; the builder only checks that the
// gradient length equals the number of penalty coordinates. Meaning not visible here.
pub tk_correction: f64,
pub tk_gradient: Option<Array1<f64>>,
/// Optional exact Jeffreys (Firth) term.
pub firth: Option<ExactJeffreysTerm>,
/// Additive correction associated with the Hessian log-determinant (builder default 0.0).
pub hessian_logdet_correction: f64,
/// Optional penalty-subspace trace kernel.
pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
/// Multiplier applied to lambda in `rho_curvature_lambda` (builder default 1.0).
pub rho_curvature_scale: f64,
/// Prior on the rho coordinates (builder default `RhoPrior::Flat`).
pub rho_prior: crate::types::RhoPrior,
/// Number of observations; used in the profiled-dispersion denominator.
pub n_observations: usize,
/// Null-space dimension; the builder defaults it to `beta.len()` minus the summed penalty
/// ranks (saturating at zero) unless an override was supplied.
pub nullspace_dim: f64,
// NOTE(review): the builder in this file always passes 0.0 through; semantics not visible here.
pub gaussian_weight_log_sum_half: f64,
/// Dispersion treatment for the outer objective.
pub dispersion: DispersionHandling,
/// Additional hyper-coordinates beyond the penalty coordinates.
pub ext_coords: Vec<HyperCoord>,
/// Optional callback returning the pair object for two extended coordinates.
pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
/// Optional callback returning the pair object for a rho / extended coordinate pair.
pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
// Optional callbacks; their contracts are defined by the alias types, not visible here.
pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
/// Optional barrier configuration; validated against `beta` by the builder.
pub barrier_config: Option<BarrierConfig>,
/// Optional KKT residual with its subspace tag.
pub kkt_residual: Option<ProjectedKktResidual>,
/// Optional active linear constraints; a block with at least one row triggers the
/// tangent-projected evaluation path in `try_tangent_projected_evaluate`.
pub active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
/// Shared stochastic-trace state; the builder starts it from `Default`.
pub stochastic_trace_state: Arc<Mutex<StochasticTraceState>>,
}
/// Matrix of the currently active linear constraints.
#[derive(Clone, Debug)]
pub struct ActiveLinearConstraintBlock {
/// Constraint matrix with one row per active constraint; its column count must equal the
/// coefficient dimension (asserted in `InnerSolutionBuilder::build`).
pub a: Array2<f64>,
}
/// Builder for [`InnerSolution`].
///
/// The first eight fields are required and supplied to `new`; the rest start from the
/// defaults set in `new` and are overridden through the chained setter methods.
pub struct InnerSolutionBuilder<'dp> {
// --- required inputs (arguments of `new`) ---
log_likelihood: f64,
penalty_quadratic: f64,
hessian_op: Arc<dyn HessianOperator>,
beta: Array1<f64>,
penalty_coords: Vec<PenaltyCoordinate>,
penalty_logdet: PenaltyLogdetDerivs,
n_observations: usize,
dispersion: DispersionHandling,
// --- optional inputs with defaults ---
// Defaults to `GaussianDerivatives`.
deriv_provider: Box<dyn HessianDerivativeProvider + 'dp>,
// Defaults to 0.0 / None.
tk_correction: f64,
tk_gradient: Option<Array1<f64>>,
firth: Option<ExactJeffreysTerm>,
// Defaults to 0.0.
hessian_logdet_correction: f64,
penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
// Defaults to 1.0.
rho_curvature_scale: f64,
// Defaults to `RhoPrior::Flat`.
rho_prior: crate::types::RhoPrior,
// When `None`, `build` derives the null-space dimension from the penalty ranks.
nullspace_dim_override: Option<f64>,
ext_coords: Vec<HyperCoord>,
ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
fixed_drift_deriv: Option<FixedDriftDerivFn>,
contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
barrier_config: Option<BarrierConfig>,
kkt_residual: Option<ProjectedKktResidual>,
active_constraints: Option<Arc<ActiveLinearConstraintBlock>>,
// Initialised to 0.0 in `new`; no setter exists in the impl block that follows this struct.
gaussian_weight_log_sum_half: f64,
}
impl<'dp> InnerSolutionBuilder<'dp> {
/// Starts a builder from the required pieces of an inner fit.
///
/// Every optional component receives its default: Gaussian derivative provider, zero
/// corrections, unit rho curvature scale, flat rho prior, and no extended coordinates,
/// callbacks, barrier, KKT residual or active constraints.
pub fn new(
log_likelihood: f64,
penalty_quadratic: f64,
beta: Array1<f64>,
n_observations: usize,
hessian_op: Arc<dyn HessianOperator>,
penalty_coords: Vec<PenaltyCoordinate>,
penalty_logdet: PenaltyLogdetDerivs,
dispersion: DispersionHandling,
) -> Self {
Self {
log_likelihood,
penalty_quadratic,
hessian_op,
beta,
penalty_coords,
penalty_logdet,
n_observations,
dispersion,
deriv_provider: Box::new(GaussianDerivatives),
tk_correction: 0.0,
tk_gradient: None,
firth: None,
hessian_logdet_correction: 0.0,
penalty_subspace_trace: None,
rho_curvature_scale: 1.0,
rho_prior: crate::types::RhoPrior::Flat,
nullspace_dim_override: None,
ext_coords: Vec::new(),
ext_coord_pair_fn: None,
rho_ext_pair_fn: None,
fixed_drift_deriv: None,
contracted_psi_second_order: None,
barrier_config: None,
kkt_residual: None,
active_constraints: None,
gaussian_weight_log_sum_half: 0.0,
}
}
/// Replaces the Hessian derivative provider.
pub fn deriv_provider(mut self, p: Box<dyn HessianDerivativeProvider + 'dp>) -> Self {
self.deriv_provider = p;
self
}
/// Sets the TK correction value and its optional gradient.
///
/// `build` requires the gradient length to equal the number of penalty coordinates.
pub fn tk(mut self, correction: f64, gradient: Option<Array1<f64>>) -> Self {
self.tk_correction = correction;
self.tk_gradient = gradient;
self
}
/// Sets (or clears) the exact Jeffreys term.
pub fn firth_term(mut self, term: Option<ExactJeffreysTerm>) -> Self {
self.firth = term;
self
}
/// Sets the Hessian log-determinant correction.
pub fn hessian_logdet_correction(mut self, correction: f64) -> Self {
self.hessian_logdet_correction = correction;
self
}
/// Sets (or clears) the penalty-subspace trace kernel.
pub fn penalty_subspace_trace(mut self, kernel: Option<Arc<PenaltySubspaceTrace>>) -> Self {
self.penalty_subspace_trace = kernel;
self
}
/// Sets the multiplier applied to lambda by `rho_curvature_lambda`.
pub fn rho_curvature_scale(mut self, scale: f64) -> Self {
self.rho_curvature_scale = scale;
self
}
/// Sets the prior on the rho coordinates.
pub fn rho_prior(mut self, prior: crate::types::RhoPrior) -> Self {
self.rho_prior = prior;
self
}
/// Overrides the null-space dimension that `build` would otherwise derive from penalty ranks.
pub fn nullspace_dim_override(mut self, dim: f64) -> Self {
self.nullspace_dim_override = Some(dim);
self
}
/// Sets the extended hyper-coordinates.
pub fn ext_coords(mut self, coords: Vec<HyperCoord>) -> Self {
self.ext_coords = coords;
self
}
/// Installs the callback producing the pair object for two extended coordinates.
pub fn ext_coord_pair_fn(
mut self,
f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
) -> Self {
self.ext_coord_pair_fn = Some(f);
self
}
/// Installs the callback producing the pair object for a rho / extended coordinate pair.
pub fn rho_ext_pair_fn(
mut self,
f: Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>,
) -> Self {
self.rho_ext_pair_fn = Some(f);
self
}
/// Installs the fixed-drift derivative callback.
pub fn fixed_drift_deriv(mut self, f: FixedDriftDerivFn) -> Self {
self.fixed_drift_deriv = Some(f);
self
}
/// Sets (or clears) the contracted psi second-order callback.
pub fn contracted_psi_second_order(mut self, f: Option<ContractedPsiSecondOrderFn>) -> Self {
self.contracted_psi_second_order = f;
self
}
/// Sets (or clears) the barrier configuration; validated in `build`.
pub fn barrier_config(mut self, config: Option<BarrierConfig>) -> Self {
self.barrier_config = config;
self
}
/// Sets (or clears) the KKT residual.
pub fn kkt_residual(mut self, residual: Option<ProjectedKktResidual>) -> Self {
self.kkt_residual = residual;
self
}
/// Sets (or clears) the active linear constraint block; validated in `build`.
pub fn active_constraints(mut self, block: Option<Arc<ActiveLinearConstraintBlock>>) -> Self {
self.active_constraints = block;
self
}
/// Validates the collected pieces and assembles the [`InnerSolution`].
///
/// The null-space dimension is the override when one was given, otherwise
/// `beta.len()` minus the summed penalty ranks, saturating at zero. The stochastic trace
/// state starts from its default.
///
/// # Panics
///
/// Panics when
/// * the Hessian operator or any penalty coordinate has a dimension other than `beta.len()`,
/// * the penalty log-determinant gradient/Hessian or the TK gradient do not match the
///   number of penalty coordinates,
/// * the barrier configuration has inconsistent lengths, a non-finite or negative `tau`,
///   an out-of-range index, a non-finite lower bound, or a direction other than ±1,
/// * the active constraint matrix width differs from `beta.len()`.
pub fn build(self) -> InnerSolution<'dp> {
let beta_dim = self.beta.len();
let penalty_dim = self.penalty_coords.len();
// Coefficient-space dimension checks.
assert_eq!(
self.hessian_op.dim(),
beta_dim,
"InnerSolutionBuilder: Hessian dimension {} does not match beta length {}",
self.hessian_op.dim(),
beta_dim
);
for (idx, coord) in self.penalty_coords.iter().enumerate() {
assert_eq!(
coord.dim(),
beta_dim,
"InnerSolutionBuilder: penalty coordinate {idx} has dimension {} but beta length is {}",
coord.dim(),
beta_dim
);
}
// Per-penalty-coordinate derivative containers must have one slot per coordinate.
assert_eq!(
self.penalty_logdet.first.len(),
penalty_dim,
"InnerSolutionBuilder: penalty logdet first-derivative length {} does not match penalty coordinate count {}",
self.penalty_logdet.first.len(),
penalty_dim
);
if let Some(second) = self.penalty_logdet.second.as_ref() {
assert!(
second.nrows() == penalty_dim && second.ncols() == penalty_dim,
"InnerSolutionBuilder: penalty logdet Hessian shape {}x{} does not match penalty coordinate count {}",
second.nrows(),
second.ncols(),
penalty_dim
);
}
if let Some(tk_gradient) = self.tk_gradient.as_ref() {
assert_eq!(
tk_gradient.len(),
penalty_dim,
"InnerSolutionBuilder: TK gradient length {} does not match penalty coordinate count {}",
tk_gradient.len(),
penalty_dim
);
}
// Barrier configuration: the three parallel vectors must agree in length before they are
// zipped, otherwise the per-entry checks below would silently skip trailing entries.
if let Some(barrier_config) = self.barrier_config.as_ref() {
assert_eq!(
barrier_config.constrained_indices.len(),
barrier_config.lower_bounds.len(),
"InnerSolutionBuilder: barrier constrained index count {} does not match lower-bound count {}",
barrier_config.constrained_indices.len(),
barrier_config.lower_bounds.len()
);
assert_eq!(
barrier_config.constrained_indices.len(),
barrier_config.bound_signs.len(),
"InnerSolutionBuilder: barrier constrained index count {} does not match bound-direction count {}",
barrier_config.constrained_indices.len(),
barrier_config.bound_signs.len()
);
assert!(
barrier_config.tau.is_finite() && barrier_config.tau >= 0.0,
"InnerSolutionBuilder: barrier tau must be finite and non-negative, got {}",
barrier_config.tau
);
for ((&idx, &lower_bound), &sign) in barrier_config
.constrained_indices
.iter()
.zip(barrier_config.lower_bounds.iter())
.zip(barrier_config.bound_signs.iter())
{
assert!(
idx < beta_dim,
"InnerSolutionBuilder: barrier constrained index {idx} out of bounds for beta length {beta_dim}"
);
assert!(
lower_bound.is_finite(),
"InnerSolutionBuilder: barrier lower bound for beta[{idx}] must be finite, got {lower_bound}"
);
assert!(
sign == 1.0 || sign == -1.0,
"InnerSolutionBuilder: barrier bound direction for beta[{idx}] must be ±1, got {sign}"
);
}
}
if let Some(active_constraints) = self.active_constraints.as_ref() {
assert_eq!(
active_constraints.a.ncols(),
beta_dim,
"InnerSolutionBuilder: active constraint width {} does not match beta length {}",
active_constraints.a.ncols(),
beta_dim
);
}
// Default null-space dimension: coefficients not covered by any penalty rank.
let nullspace_dim = self.nullspace_dim_override.unwrap_or_else(|| {
let penalty_rank: usize = self
.penalty_coords
.iter()
.map(PenaltyCoordinate::rank)
.sum();
beta_dim.saturating_sub(penalty_rank) as f64
});
InnerSolution {
log_likelihood: self.log_likelihood,
penalty_quadratic: self.penalty_quadratic,
hessian_op: self.hessian_op,
beta: self.beta,
penalty_coords: self.penalty_coords,
penalty_logdet: self.penalty_logdet,
deriv_provider: self.deriv_provider,
tk_correction: self.tk_correction,
tk_gradient: self.tk_gradient,
firth: self.firth,
hessian_logdet_correction: self.hessian_logdet_correction,
penalty_subspace_trace: self.penalty_subspace_trace,
rho_curvature_scale: self.rho_curvature_scale,
rho_prior: self.rho_prior,
n_observations: self.n_observations,
nullspace_dim,
gaussian_weight_log_sum_half: self.gaussian_weight_log_sum_half,
dispersion: self.dispersion,
ext_coords: self.ext_coords,
ext_coord_pair_fn: self.ext_coord_pair_fn,
rho_ext_pair_fn: self.rho_ext_pair_fn,
fixed_drift_deriv: self.fixed_drift_deriv,
contracted_psi_second_order: self.contracted_psi_second_order,
barrier_config: self.barrier_config,
kkt_residual: self.kkt_residual,
active_constraints: self.active_constraints,
// Each built solution gets its own fresh stochastic-trace state.
stochastic_trace_state: Arc::new(Mutex::new(StochasticTraceState::default())),
}
}
}
/// Selects how many derivative orders an evaluation is asked to produce.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EvalMode {
/// Objective value only.
ValueOnly,
/// Objective value and gradient.
ValueAndGradient,
/// Objective value, gradient and Hessian. `try_tangent_projected_evaluate` rejects this
/// mode when extended coordinates with pair callbacks are present.
ValueGradientHessian,
}
/// Output of a REML/LAML evaluation.
#[derive(Debug)]
pub struct RemlLamlResult {
/// Objective (cost) value.
pub cost: f64,
// NOTE(review): presumably a diagnostic of the implicit-function-theorem residual correction;
// the producer is outside this section — confirm.
pub ift_residual_energy: Option<f64>,
// NOTE(review): presumably a suggested coefficient-space polishing step — confirm at the producer.
pub inner_polish_step: Option<Array1<f64>>,
/// Gradient of the cost, when one was produced.
pub gradient: Option<Array1<f64>>,
/// Hessian information in the representation defined by the outer strategy module.
pub hessian: crate::solver::outer_strategy::HessianResult,
// NOTE(review): presumably mode-response columns per rho / extended coordinate; shapes are
// not visible in this section.
pub rho_mode_response_cols: Option<Array2<f64>>,
pub ext_mode_response_cols: Option<Array2<f64>>,
}
use crate::solver::estimate::smooth_floor_dp;
/// Lower bound applied to `n_observations - nullspace_dim` in `efs_profiling` so the
/// profiled dispersion never divides by zero or by a negative number.
const DENOM_RIDGE: f64 = 1e-8;
/// Thin wrapper that forwards to `PenaltyCoordinate::apply_shifted_penalty(beta, lambda)`.
// NOTE(review): presumably returns the lambda-scaled penalty applied to `beta`; the exact
// "shifted" convention is defined by `PenaltyCoordinate` — confirm there.
fn penalty_a_k_beta(coord: &PenaltyCoordinate, beta: &Array1<f64>, lambda: f64) -> Array1<f64> {
coord.apply_shifted_penalty(beta, lambda)
}
/// Thin wrapper that forwards to `PenaltyCoordinate::shifted_quadratic(beta, lambda)`.
// NOTE(review): presumably the scalar quadratic form matching `penalty_a_k_beta`; the exact
// "shifted" convention is defined by `PenaltyCoordinate` — confirm there.
fn penalty_a_k_quadratic(coord: &PenaltyCoordinate, beta: &Array1<f64>, lambda: f64) -> f64 {
coord.shifted_quadratic(beta, lambda)
}
#[inline]
fn rho_curvature_lambda(solution: &InnerSolution<'_>, lambda: f64) -> f64 {
solution.rho_curvature_scale * lambda
}
/// Wraps an owned penalty coordinate and a fixed multiplier as a shared [`HyperOperator`].
///
/// Every product delegates to the coordinate's own penalty routines with `scale` folded in.
/// The operator reports itself as non-implicit.
fn penalty_coord_to_operator(coord: PenaltyCoordinate, scale: f64) -> Arc<dyn HyperOperator> {
    /// Penalty coordinate plus the multiplier applied to every product.
    struct OwnedPenaltyHyperOperator {
        penalty: PenaltyCoordinate,
        weight: f64,
    }

    impl HyperOperator for OwnedPenaltyHyperOperator {
        fn dim(&self) -> usize {
            self.penalty.dim()
        }

        fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
            // Same zero-initialised buffer + in-place product as the view variant.
            self.mul_vec_view(v.view())
        }

        fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
            let mut product = Array1::<f64>::zeros(v.len());
            self.mul_vec_into(v, product.view_mut());
            product
        }

        fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
            self.penalty.apply_penalty_view_into(v, self.weight, out);
        }

        fn scaled_add_mul_vec(
            &self,
            v: ArrayView1<'_, f64>,
            scale: f64,
            out: ArrayViewMut1<'_, f64>,
        ) {
            // An exactly-zero multiplier contributes nothing, so `out` is left untouched.
            if scale != 0.0 {
                let combined = scale * self.weight;
                self.penalty.scaled_add_penalty_view(v, combined, out);
            }
        }

        fn to_dense(&self) -> Array2<f64> {
            self.penalty.scaled_dense_matrix(self.weight)
        }

        fn is_implicit(&self) -> bool {
            false
        }
    }

    Arc::new(OwnedPenaltyHyperOperator {
        penalty: coord,
        weight: scale,
    })
}
fn penalty_total_drift_result(
coord: &PenaltyCoordinate,
scale: f64,
correction: Option<&DriftDerivResult>,
) -> DriftDerivResult {
match correction {
Some(DriftDerivResult::Dense(corr)) => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: Some(corr.clone()),
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
let mut dense = coord.scaled_dense_matrix(scale);
dense += corr;
DriftDerivResult::Dense(dense)
}
}
Some(DriftDerivResult::Operator(corr_op)) => {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: if coord.uses_operator_fast_path() {
None
} else {
Some(coord.scaled_dense_matrix(scale))
},
operators: {
let mut ops = vec![Arc::clone(corr_op)];
if coord.uses_operator_fast_path() {
ops.push(penalty_coord_to_operator(coord.clone(), scale));
}
ops
},
dim_hint: coord.dim(),
}))
}
None => {
if coord.uses_operator_fast_path() {
DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
dense: None,
operators: vec![penalty_coord_to_operator(coord.clone(), scale)],
dim_hint: coord.dim(),
}))
} else {
DriftDerivResult::Dense(coord.scaled_dense_matrix(scale))
}
}
}
}
/// Collects the operator-valued parts of a drift: the block-local piece (cloned into a new
/// `Arc`) first, then the shared operator handle, each only when present.
fn hyper_coord_drift_operators(drift: &HyperCoordDrift) -> Vec<Arc<dyn HyperOperator>> {
    let mut collected: Vec<Arc<dyn HyperOperator>> = Vec::new();
    if let Some(local) = &drift.block_local {
        collected.push(Arc::new(local.clone()));
    }
    if let Some(shared) = &drift.operator {
        collected.push(Arc::clone(shared));
    }
    collected
}
/// Returns the drift as a single operator handle, or `None` when it has no operator parts.
///
/// A lone operator without a dense part is returned as-is; anything else is wrapped in a
/// `CompositeHyperOperator` carrying the dense part and `dim_hint`.
fn hyper_coord_drift_operator_arc(
    drift: &HyperCoordDrift,
    dim_hint: usize,
) -> Option<Arc<dyn HyperOperator>> {
    let mut parts = hyper_coord_drift_operators(drift);
    match (parts.len(), drift.dense.is_none()) {
        (0, _) => None,
        (1, true) => Some(parts.pop().expect("single operator drift")),
        _ => Some(Arc::new(CompositeHyperOperator {
            dense: drift.dense.clone(),
            operators: parts,
            dim_hint,
        })),
    }
}
/// Packs a dense part and a list of operators into a [`DriftDerivResult`].
///
/// * no operators: dense result, falling back to a `dim_hint × dim_hint` zero matrix;
/// * exactly one operator and no dense part: that operator directly;
/// * otherwise: a composite operator over all parts.
fn drift_parts_into_result(
    dense: Option<Array2<f64>>,
    mut operators: Vec<Arc<dyn HyperOperator>>,
    dim_hint: usize,
) -> DriftDerivResult {
    match (dense, operators.len()) {
        (dense_part, 0) => DriftDerivResult::Dense(
            dense_part.unwrap_or_else(|| Array2::<f64>::zeros((dim_hint, dim_hint))),
        ),
        (None, 1) => DriftDerivResult::Operator(operators.pop().expect("single operator drift")),
        (dense_part, _) => DriftDerivResult::Operator(Arc::new(CompositeHyperOperator {
            dense: dense_part,
            operators,
            dim_hint,
        })),
    }
}
/// Merges a drift with an optional correction, returning the dense part and operator list.
///
/// A dense correction is added to (or becomes) the dense part; an operator correction is
/// appended after the drift's own operators.
fn hyper_coord_total_drift_parts(
    drift: &HyperCoordDrift,
    correction: Option<&DriftDerivResult>,
) -> (Option<Array2<f64>>, Vec<Arc<dyn HyperOperator>>) {
    let mut operators = hyper_coord_drift_operators(drift);
    let dense = match correction {
        None => drift.dense.clone(),
        Some(DriftDerivResult::Operator(extra)) => {
            operators.push(Arc::clone(extra));
            drift.dense.clone()
        }
        Some(DriftDerivResult::Dense(matrix)) => Some(match drift.dense.as_ref() {
            Some(base) => {
                let mut sum = base.clone();
                sum += matrix;
                sum
            }
            None => matrix.clone(),
        }),
    };
    (dense, operators)
}
/// Convenience composition: merge the drift with its correction, then pack the parts into a
/// single [`DriftDerivResult`].
fn hyper_coord_total_drift_result(
    drift: &HyperCoordDrift,
    correction: Option<&DriftDerivResult>,
    dim_hint: usize,
) -> DriftDerivResult {
    let merged = hyper_coord_total_drift_parts(drift, correction);
    drift_parts_into_result(merged.0, merged.1, dim_hint)
}
/// Effective quadratic coefficient for the EFS step.
///
/// Under profiled dispersion this is `2 * dp_cgrad * a_i / phi`; with fixed dispersion it is
/// `2 * a_i`, and `dp_cgrad` and `phi` are ignored.
#[inline]
fn efs_q_eff(a_i: f64, dispersion: &DispersionHandling, dp_cgrad: f64, phi: f64) -> f64 {
    let profiled = match dispersion {
        DispersionHandling::ProfiledGaussian => true,
        DispersionHandling::Fixed { .. } => false,
    };
    if profiled {
        2.0 * dp_cgrad * a_i / phi
    } else {
        2.0 * a_i
    }
}
fn gamma_precision_rate_for_rho(prior: &crate::types::RhoPrior, idx: usize) -> Option<f64> {
match prior {
crate::types::RhoPrior::GammaPrecision { rate, .. } => Some(*rate),
crate::types::RhoPrior::Independent(priors) => {
priors.get(idx).and_then(|prior| match prior {
crate::types::RhoPrior::GammaPrecision { rate, .. } => Some(*rate),
_ => None,
})
}
_ => None,
}
}
/// Adds the Gamma-precision prior contribution `2 * rate * lambda` to `base_q_eff`.
///
/// The base value is returned unchanged when coordinate `idx` has no Gamma-precision rate or
/// the rate is non-finite or not strictly positive.
#[inline]
fn efs_q_eff_with_gamma_rate(
    base_q_eff: f64,
    lambda: f64,
    prior: &crate::types::RhoPrior,
    idx: usize,
) -> f64 {
    gamma_precision_rate_for_rho(prior, idx)
        .filter(|rate| rate.is_finite() && *rate > 0.0)
        .map_or(base_q_eff, |rate| base_q_eff + 2.0 * rate * lambda)
}
/// Converts a gradient entry into a clamped EFS log-step.
///
/// Returns `None` unless `q_eff` is finite and strictly positive and `g_full` is finite.
/// Otherwise the step is `ln(1 - 2 * g_full / q_eff)` clamped to `±EFS_MAX_STEP`, or
/// `-EFS_MAX_STEP` when the ratio is not strictly positive.
#[inline]
fn efs_log_step_from_grad(q_eff: f64, g_full: f64) -> Option<f64> {
    let inputs_usable = q_eff.is_finite() && q_eff > 0.0 && g_full.is_finite();
    if !inputs_usable {
        return None;
    }
    let ratio = 1.0 - 2.0 * g_full / q_eff;
    let step = if ratio > 0.0 {
        ratio.ln().clamp(-EFS_MAX_STEP, EFS_MAX_STEP)
    } else {
        -EFS_MAX_STEP
    };
    Some(step)
}
/// Returns `(phi, dp_cgrad)` for the EFS update.
///
/// With fixed dispersion this is `(phi, 0.0)`. With profiled dispersion the raw penalized
/// deviance `-2 * log_likelihood + penalty_quadratic` is passed through `smooth_floor_dp`, and
/// the floored value is divided by `n_observations - nullspace_dim`, bounded below by
/// `DENOM_RIDGE`. The second element is the second component returned by `smooth_floor_dp`.
#[inline]
fn efs_profiling(solution: &InnerSolution<'_>) -> (f64, f64) {
    match &solution.dispersion {
        DispersionHandling::Fixed { phi, .. } => (*phi, 0.0),
        DispersionHandling::ProfiledGaussian => {
            let raw_deviance = -2.0 * solution.log_likelihood + solution.penalty_quadratic;
            let (floored, floor_slope, _) = smooth_floor_dp(raw_deviance);
            let residual_dof = solution.n_observations as f64 - solution.nullspace_dim;
            (floored / residual_dof.max(DENOM_RIDGE), floor_slope)
        }
    }
}
/// Dispatches a cross trace to the `HessianOperator` routine matching the available
/// representations of the left and right drifts.
///
/// An operator is preferred on each side when supplied; otherwise the cached dense matrix for
/// that side is used.
///
/// # Panics
///
/// Panics when a side has no operator and its dense matrix is also missing.
fn trace_hinv_cached_drift_cross(
    hop: &dyn HessianOperator,
    left_dense: Option<&Array2<f64>>,
    left_op: Option<&dyn HyperOperator>,
    right_dense: Option<&Array2<f64>>,
    right_op: Option<&dyn HyperOperator>,
) -> f64 {
    let cached_left = || left_dense.expect("left dense drift should be cached");
    let cached_right = || right_dense.expect("right dense drift should be cached");

    if let Some(left) = left_op {
        return match right_op {
            Some(right) => hop.trace_hinv_operator_cross(left, right),
            None => hop.trace_hinv_matrix_operator_cross(cached_right(), left),
        };
    }
    match right_op {
        Some(right) => hop.trace_hinv_matrix_operator_cross(cached_left(), right),
        None => hop.trace_hinv_product_cross(cached_left(), cached_right()),
    }
}
/// Assembles one entry of the outer gradient.
///
/// The entry is `penalty + 0.5 * trace_logdet_i - 0.5 * ld_s_i`, with the trace and
/// determinant halves included only when their flags are set. The penalty part is `a_i` for
/// fixed dispersion and `dp_cgrad * a_i / profiled_scale` for profiled dispersion.
#[inline]
fn outer_gradient_entry(
    a_i: f64,
    trace_logdet_i: f64,
    ld_s_i: f64,
    dispersion: &DispersionHandling,
    dp_cgrad: f64,
    profiled_scale: f64,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    // Half of `value` when the term is enabled, exactly zero otherwise.
    let gated_half = |enabled: bool, value: f64| if enabled { 0.5 * value } else { 0.0 };

    let penalty_part = match dispersion {
        DispersionHandling::Fixed { .. } => a_i,
        DispersionHandling::ProfiledGaussian => dp_cgrad * a_i / profiled_scale,
    };
    let hessian_part = gated_half(incl_logdet_h, trace_logdet_i);
    let penalty_det_part = gated_half(incl_logdet_s, ld_s_i);
    penalty_part + hessian_part - penalty_det_part
}
/// Assembles one entry of the outer Hessian from its scalar ingredients.
///
/// The entry is the sum of three parts:
/// * the quadratic part `pair_a - g_i_dot_v_j`, which under profiled dispersion is rescaled
///   by `dp_cgrad / phi` and gains the `a_i * a_j` coupling term;
/// * `0.5 * (cross_trace + h2_trace)` when `incl_logdet_h` is set;
/// * `-0.5 * pair_ld_s` when `incl_logdet_s` is set.
#[inline]
fn outer_hessian_entry(
    a_i: f64,
    a_j: f64,
    g_i_dot_v_j: f64,
    pair_a: f64,
    cross_trace: f64,
    h2_trace: f64,
    pair_ld_s: f64,
    profiled_phi: f64,
    profiled_nu: f64,
    profiled_dp_cgrad: f64,
    profiled_dp_cgrad2: f64,
    is_profiled: bool,
    incl_logdet_h: bool,
    incl_logdet_s: bool,
) -> f64 {
    let raw_quadratic = pair_a - g_i_dot_v_j;

    let quadratic_part = if !is_profiled {
        raw_quadratic
    } else {
        let rescaled = profiled_dp_cgrad * raw_quadratic / profiled_phi;
        let curvature = profiled_dp_cgrad2 * profiled_nu * profiled_phi
            - profiled_dp_cgrad * profiled_dp_cgrad;
        let coupling =
            2.0 * curvature * a_i * a_j / (profiled_nu * profiled_phi * profiled_phi);
        rescaled + coupling
    };

    let hessian_logdet_part = match incl_logdet_h {
        true => 0.5 * (cross_trace + h2_trace),
        false => 0.0,
    };
    let penalty_logdet_part = match incl_logdet_s {
        true => -0.5 * pair_ld_s,
        false => 0.0,
    };

    quadratic_part + hessian_logdet_part + penalty_logdet_part
}
/// Builds a basis for the directions left free by the active constraints `a_act`.
///
/// The basis consists of the leading eigenvectors of `a_actᵀ a_act`, one per eigenvalue at or
/// below `positive_eigenvalue_threshold`.
///
/// Returns `None` when there are no constraint rows, the eigendecomposition fails or yields
/// non-contiguous eigenvalues, no eigenvalue is at or below the threshold, or every
/// eigenvalue is.
// NOTE(review): taking the first `tangent_dim` columns assumes `eigh` returns eigenvalues in
// ascending order — confirm against the `eigh` wrapper used by this crate.
fn compute_active_constraint_tangent_basis(a_act: &Array2<f64>) -> Option<Array2<f64>> {
    if a_act.nrows() == 0 {
        return None;
    }
    let ambient_dim = a_act.ncols();
    let gram = a_act.t().dot(a_act);
    let (spectrum, vectors) = gram.eigh(faer::Side::Lower).ok()?;
    let values = spectrum.as_slice()?;
    let cutoff = positive_eigenvalue_threshold(values);
    let tangent_dim = values.iter().filter(|&&value| value <= cutoff).count();
    match tangent_dim {
        0 => None,
        n if n == ambient_dim => None,
        n => Some(vectors.slice(ndarray::s![.., 0..n]).to_owned()),
    }
}
/// Materialises the unit-scaled dense penalty matrix of `coord`.
///
/// # Panics
///
/// Panics when the matrix is not `p × p`.
fn materialize_penalty_coord_dense(coord: &PenaltyCoordinate, p: usize) -> Array2<f64> {
    let dense = coord.scaled_dense_matrix(1.0);
    // Rows are checked before columns, both with the same message.
    for extent in [dense.nrows(), dense.ncols()] {
        assert_eq!(extent, p, "penalty coord dim mismatch");
    }
    dense
}
/// Reassembles a dense matrix `V diag(σ) Vᵀ` from the spectral operator's eigenvectors.
///
/// For an active direction with regularised eigenvalue `r`, the weight is `r - ε² / r` with
/// `ε = sqrt(f64::EPSILON) * max(p, 1)`, or zero when `r` is exactly zero. Inactive directions
/// get weight zero. An empty operator yields a `0 × 0` matrix.
fn assemble_h_raw_dense(op: &DenseSpectralOperator) -> Array2<f64> {
    let dim = op.n_dim;
    if dim == 0 {
        return Array2::<f64>::zeros((0, 0));
    }
    let ridge = f64::EPSILON.sqrt() * (dim as f64).max(1.0);
    let ridge_sq = ridge * ridge;

    let mut weighted = op.eigenvectors.clone();
    for j in 0..dim {
        let weight = if !op.active_mask[j] {
            0.0
        } else {
            let reg = op.reg_eigenvalues[j];
            if reg == 0.0 { 0.0 } else { reg - ridge_sq / reg }
        };
        // A unit weight leaves the column untouched.
        if weight == 1.0 {
            continue;
        }
        let mut column = weighted.column_mut(j);
        if weight == 0.0 {
            column.fill(0.0);
        } else {
            column.mapv_inplace(|entry| entry * weight);
        }
    }
    crate::faer_ndarray::fast_abt(&weighted, &op.eigenvectors)
}
/// Hessian operator restricted to the tangent space spanned by the columns of `z`.
struct TangentProjectedHessianOperator {
// Tangent basis; its row count is reported as the operator dimension.
z: Array2<f64>,
// Spectral operator of the projected Hessian `zᵀ H z`, acting in tangent coordinates.
h_t_op: DenseSpectralOperator,
}
impl TangentProjectedHessianOperator {
    /// Restricts a full-space matrix to tangent coordinates: `zᵀ a z`.
    fn restrict_to_tangent(&self, a: &Array2<f64>) -> Array2<f64> {
        self.z.t().dot(a).dot(&self.z)
    }
}

/// Full-space facade over the tangent-space spectral operator: inputs are restricted with
/// `zᵀ`, handled by `h_t_op`, and solve results are lifted back with `z`.
impl HessianOperator for TangentProjectedHessianOperator {
    fn dim(&self) -> usize {
        self.z.nrows()
    }

    fn active_rank(&self) -> usize {
        self.h_t_op.active_rank()
    }

    fn logdet(&self) -> f64 {
        self.h_t_op.logdet()
    }

    fn is_dense(&self) -> bool {
        self.h_t_op.is_dense()
    }

    fn logdet_traces_match_hinv_kernel(&self) -> bool {
        self.h_t_op.logdet_traces_match_hinv_kernel()
    }

    fn solve(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let tangent_rhs = self.z.t().dot(rhs);
        let tangent_solution = self.h_t_op.solve(&tangent_rhs);
        self.z.dot(&tangent_solution)
    }

    fn solve_multi(&self, rhs: &Array2<f64>) -> Array2<f64> {
        let tangent_rhs = self.z.t().dot(rhs);
        let tangent_solution = self.h_t_op.solve_multi(&tangent_rhs);
        self.z.dot(&tangent_solution)
    }

    fn trace_hinv_product(&self, a: &Array2<f64>) -> f64 {
        let restricted = self.restrict_to_tangent(a);
        self.h_t_op.trace_hinv_product(&restricted)
    }

    fn trace_logdet_gradient(&self, a: &Array2<f64>) -> f64 {
        let restricted = self.restrict_to_tangent(a);
        self.h_t_op.trace_logdet_gradient(&restricted)
    }
}
/// Pseudo-log-determinant of the tangent-projected penalty and its derivatives in
/// `rho = ln(lambda)`.
///
/// With `S_t = Σ_k λ_k zᵀ S_k z` and `S_t⁺` its pseudo-inverse over eigenvalues above
/// `positive_eigenvalue_threshold`, the result holds
/// * `value`: `exact_pseudo_logdet` of the eigenvalues,
/// * `first[k] = λ_k tr(S_t⁺ zᵀ S_k z)`,
/// * `second[k][l] = δ_kl first[k] − λ_k λ_l tr(S_t⁺ zᵀS_k z · S_t⁺ zᵀS_l z)`.
///
/// `p` is the full coefficient dimension every penalty coordinate must have.
///
/// # Errors
///
/// Returns `Err` when the number of `lambdas` differs from the number of penalty coordinates,
/// when the eigendecomposition fails, or when it yields non-contiguous eigenvalues.
///
/// # Panics
///
/// Panics when a penalty coordinate does not materialise to a `p × p` matrix.
fn tangent_penalty_logdet(
    z: &Array2<f64>,
    penalty_coords: &[PenaltyCoordinate],
    lambdas: &[f64],
    p: usize,
) -> Result<PenaltyLogdetDerivs, String> {
    // One lambda per penalty coordinate is required. Extra lambdas would index past the
    // projected penalties, and missing ones would silently drop penalties from `S_t`.
    if lambdas.len() != penalty_coords.len() {
        return Err(format!(
            "tangent penalty logdet: {} smoothing parameters supplied for {} penalty coordinates",
            lambdas.len(),
            penalty_coords.len()
        ));
    }
    let m = z.ncols();
    let k = lambdas.len();

    // Tangent-space penalties zᵀ S_k z, one per coordinate.
    let zsz: Vec<Array2<f64>> = penalty_coords
        .iter()
        .map(|c| {
            let s_k_full = materialize_penalty_coord_dense(c, p);
            z.t().dot(&s_k_full).dot(z)
        })
        .collect();

    let mut s_t = Array2::<f64>::zeros((m, m));
    for (&lambda, projected) in lambdas.iter().zip(zsz.iter()) {
        s_t.scaled_add(lambda, projected);
    }

    let (evals, evecs) = s_t
        .eigh(faer::Side::Lower)
        .map_err(|e| format!("tangent S eigendecomposition failed: {e}"))?;
    let evals_slice = evals.as_slice().ok_or_else(|| {
        "tangent S eigendecomposition returned non-contiguous eigenvalues".to_string()
    })?;
    let threshold = positive_eigenvalue_threshold(evals_slice);
    let value = exact_pseudo_logdet(evals_slice, threshold);

    // Pseudo-inverse assembled from the eigenpairs above the threshold.
    let mut s_t_plus = Array2::<f64>::zeros((m, m));
    for j in 0..m {
        if evals[j] > threshold {
            let inv = 1.0 / evals[j];
            for r in 0..m {
                let factor = evecs[[r, j]] * inv;
                for c in 0..m {
                    s_t_plus[[r, c]] += factor * evecs[[c, j]];
                }
            }
        }
    }

    let mut first = Array1::<f64>::zeros(k);
    for k_idx in 0..k {
        first[k_idx] = lambdas[k_idx] * trace_matrix_product(&s_t_plus, &zsz[k_idx]);
    }

    // The diagonal starts from the first derivative because dλ_k/dρ_k = λ_k.
    let mut second = Array2::<f64>::zeros((k, k));
    for k_idx in 0..k {
        second[[k_idx, k_idx]] += first[k_idx];
    }
    let s_plus_zsz: Vec<Array2<f64>> = zsz.iter().map(|m_k| s_t_plus.dot(m_k)).collect();
    for k_idx in 0..k {
        for l_idx in 0..=k_idx {
            let cross = trace_matrix_product(&s_plus_zsz[k_idx], &s_plus_zsz[l_idx]);
            let entry = -lambdas[k_idx] * lambdas[l_idx] * cross;
            second[[k_idx, l_idx]] += entry;
            // Mirror the lower triangle to keep the matrix symmetric.
            if l_idx != k_idx {
                second[[l_idx, k_idx]] += entry;
            }
        }
    }

    Ok(PenaltyLogdetDerivs {
        value,
        first,
        second: Some(second),
    })
}
/// Adapter that lets a borrowed `HessianDerivativeProvider` be boxed as an owned provider.
///
/// `try_tangent_projected_evaluate` uses it to reuse the original solution's provider inside
/// the projected `InnerSolution`.
struct BorrowedDerivProvider<'a>(&'a dyn HessianDerivativeProvider);
// Every method below forwards unchanged to the wrapped provider, so any override made by the
// underlying implementation is preserved instead of falling back to a trait default.
impl<'a> HessianDerivativeProvider for BorrowedDerivProvider<'a> {
fn hessian_derivative_correction(
&self,
v: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
self.0.hessian_derivative_correction(v)
}
fn hessian_derivative_correction_result(
&self,
v: &Array1<f64>,
) -> Result<Option<DriftDerivResult>, String> {
self.0.hessian_derivative_correction_result(v)
}
fn hessian_derivative_corrections_result(
&self,
vs: &[Array1<f64>],
) -> Result<Vec<Option<DriftDerivResult>>, String> {
self.0.hessian_derivative_corrections_result(vs)
}
fn has_batched_hessian_derivative_corrections(&self) -> bool {
self.0.has_batched_hessian_derivative_corrections()
}
fn hessian_second_derivative_correction(
&self,
v_k: &Array1<f64>,
v_l: &Array1<f64>,
u_kl: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
self.0.hessian_second_derivative_correction(v_k, v_l, u_kl)
}
fn hessian_second_derivative_correction_result(
&self,
v_k: &Array1<f64>,
v_l: &Array1<f64>,
u_kl: &Array1<f64>,
) -> Result<Option<DriftDerivResult>, String> {
self.0
.hessian_second_derivative_correction_result(v_k, v_l, u_kl)
}
fn hessian_second_derivative_corrections_result(
&self,
triples: &[(Array1<f64>, Array1<f64>, Array1<f64>)],
) -> Result<Vec<Option<DriftDerivResult>>, String> {
self.0.hessian_second_derivative_corrections_result(triples)
}
fn has_batched_hessian_second_derivative_corrections(&self) -> bool {
self.0.has_batched_hessian_second_derivative_corrections()
}
fn has_corrections(&self) -> bool {
self.0.has_corrections()
}
fn outer_hessian_derivative_kernel(&self) -> Option<OuterHessianDerivativeKernel> {
self.0.outer_hessian_derivative_kernel()
}
fn family_outer_hessian_operator(
&self,
) -> Option<Arc<dyn crate::solver::outer_strategy::OuterHessianOperator>> {
self.0.family_outer_hessian_operator()
}
fn scalar_glm_ingredients(&self) -> Option<ScalarGlmIngredients<'_>> {
self.0.scalar_glm_ingredients()
}
}
/// Evaluates the REML/LAML objective on the tangent space of the active linear constraints.
///
/// Returns `Ok(None)` when the solution has no active constraint rows, so the caller can use
/// the unconstrained path. Otherwise the Hessian, penalty log-determinant, Hessian
/// log-determinant correction and Jeffreys term are projected onto the tangent basis `z`
/// from `compute_active_constraint_tangent_basis`. A projected `InnerSolution` is then
/// evaluated with `reml_laml_evaluate`. The pair callbacks, fixed-drift callback, contracted
/// psi callback and penalty-subspace trace are dropped from the projected solution.
///
/// # Errors
///
/// Returns `Err` when
/// * the constraint matrix width differs from `beta.len()`,
/// * `mode` is `ValueGradientHessian` while extended coordinates with pair callbacks exist,
/// * no tangent basis can be built,
/// * dense Hessian assembly, an eigendecomposition or the projected penalty log-determinant
///   fails, or
/// * the inner `reml_laml_evaluate` call fails.
fn try_tangent_projected_evaluate(
    solution: &InnerSolution<'_>,
    rho: &[f64],
    mode: EvalMode,
    prior_cost_gradient: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
) -> Result<Option<RemlLamlResult>, String> {
    let block = match solution.active_constraints.as_ref() {
        Some(b) if b.a.nrows() > 0 => b,
        _ => return Ok(None),
    };
    let p = solution.beta.len();
    if block.a.ncols() != p {
        return Err(format!(
            "active_constraints.a has {} columns but β is {}-dim",
            block.a.ncols(),
            p
        ));
    }

    // Reject the unsupported combination before any dense assembly or eigendecomposition,
    // so the request fails without paying for work that would be thrown away.
    if mode == EvalMode::ValueGradientHessian
        && !solution.ext_coords.is_empty()
        && (solution.ext_coord_pair_fn.is_some() || solution.rho_ext_pair_fn.is_some())
    {
        return Err(
            "active constraints + ext_coords + mode=ValueGradientHessian not yet supported; \
             fall back to ValueAndGradient. The ext-coord pair callbacks return p-space \
             second-drift objects that the tangent hessian wrapper does not re-project."
                .to_string(),
        );
    }

    // NOTE: the first value printed below is the constraint row count, used as a stand-in for
    // the rank; the helper does not report which of its `None` conditions was hit.
    let z = match compute_active_constraint_tangent_basis(&block.a) {
        Some(z) => z,
        None => {
            return Err(format!(
                "active constraint matrix has rank {} on {}-dim space; \
                 tangent manifold is a single point ({{β̂}}), no outer \
                 derivative is defined",
                block.a.nrows(),
                p
            ));
        }
    };

    let h_full = solution
        .hessian_op
        .assemble_h_dense_for_tangent_projection()?;
    let h_t = z.t().dot(&h_full).dot(&z);
    let h_t_op = DenseSpectralOperator::from_symmetric(&h_t)
        .map_err(|e| format!("tangent H eigendecomposition failed: {e}"))?;

    let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
    let projected_logdet = tangent_penalty_logdet(&z, &solution.penalty_coords, &lambdas, p)?;
    let projected_kkt = solution.kkt_residual.clone();

    // The correction is rescaled by the fraction of dimensions kept in the tangent space.
    let m_tangent = z.ncols();
    let projected_hlogdet_correction = if p == 0 {
        0.0
    } else {
        solution.hessian_logdet_correction * (m_tangent as f64 / p as f64)
    };

    // A Jeffreys term backed by an operator is re-evaluated on the tangent basis; a term
    // without an operator is carried over unchanged.
    let projected_firth = solution
        .firth
        .as_ref()
        .map(|term| match term.operator_arc() {
            Some(op_arc) => {
                let projected_value = op_arc.jeffreys_logdet_projected(z.view());
                ExactJeffreysTerm::with_projected_value(op_arc, projected_value)
            }
            None => term.clone(),
        });

    // Every borrow of `z` has ended, so the basis is moved into the wrapper without a copy.
    let wrapper = TangentProjectedHessianOperator { z, h_t_op };

    let projected = InnerSolution {
        log_likelihood: solution.log_likelihood,
        penalty_quadratic: solution.penalty_quadratic,
        hessian_op: Arc::new(wrapper),
        beta: solution.beta.clone(),
        penalty_coords: solution.penalty_coords.clone(),
        penalty_logdet: projected_logdet,
        deriv_provider: Box::new(BorrowedDerivProvider(solution.deriv_provider.as_ref())),
        tk_correction: solution.tk_correction,
        tk_gradient: solution.tk_gradient.clone(),
        firth: projected_firth,
        hessian_logdet_correction: projected_hlogdet_correction,
        penalty_subspace_trace: None,
        rho_curvature_scale: solution.rho_curvature_scale,
        rho_prior: solution.rho_prior.clone(),
        n_observations: solution.n_observations,
        nullspace_dim: solution.nullspace_dim,
        gaussian_weight_log_sum_half: solution.gaussian_weight_log_sum_half,
        dispersion: solution.dispersion.clone(),
        ext_coords: solution.ext_coords.clone(),
        ext_coord_pair_fn: None,
        rho_ext_pair_fn: None,
        contracted_psi_second_order: None,
        fixed_drift_deriv: None,
        barrier_config: solution.barrier_config.clone(),
        kkt_residual: projected_kkt,
        // Cleared so the nested evaluation does not take the tangent path again.
        active_constraints: None,
        stochastic_trace_state: solution.stochastic_trace_state.clone(),
    };
    let result = reml_laml_evaluate(&projected, rho, mode, prior_cost_gradient)?;
    Ok(Some(result))
}