use super::*;
/// Downcasts a type-erased hyper-operator to an [`ImplicitHyperOperator`], returning
/// `None` when the concrete type differs.
pub(crate) fn as_implicit(op: &dyn HyperOperator) -> Option<&ImplicitHyperOperator> {
    let erased = op.as_any();
    erased.downcast_ref::<ImplicitHyperOperator>()
}
/// Downcasts a type-erased hyper-operator to a [`CompositeHyperOperator`], returning
/// `None` when the concrete type differs.
pub(crate) fn as_composite(op: &dyn HyperOperator) -> Option<&CompositeHyperOperator> {
    let erased = op.as_any();
    erased.downcast_ref::<CompositeHyperOperator>()
}
/// Downcasts a type-erased hyper-operator to a `WeightedHyperOperator`, returning
/// `None` when the concrete type differs.
pub(crate) fn as_weighted(op: &dyn HyperOperator) -> Option<&WeightedHyperOperator> {
    let erased = op.as_any();
    erased.downcast_ref::<WeightedHyperOperator>()
}
/// Log-determinant trace helpers for a drift derivative that may be held either as a
/// dense matrix or as a hyper-operator.
pub(crate) trait DriftDerivTraceExt {
    /// Trace contribution of this drift, dispatched to the matching `trace_logdet_*`
    /// routine of `hessian`.
    fn trace_logdet(&self, hessian: &dyn HessianOperator) -> f64;
    /// Cross term pairing this drift with `rhs`, dispatched to the matching
    /// `trace_logdet_hessian_cross*` routine of `hessian`.
    fn trace_logdet_hessian_cross(&self, rhs: &Self, hessian: &dyn HessianOperator) -> f64;
}
impl DriftDerivTraceExt for DriftDerivResult {
    /// Operator-backed drifts go through `trace_logdet_operator`; dense drifts through
    /// `trace_logdet_gradient`.
    fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Operator(op) => hop.trace_logdet_operator(op.as_ref()),
            Self::Dense(dense) => hop.trace_logdet_gradient(dense),
        }
    }

    /// Picks the cross-trace kernel by storage kind. For mixed pairs the dense matrix is
    /// always passed first, regardless of which side it came from.
    fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
        match (self, rhs) {
            (Self::Operator(lhs_op), Self::Operator(rhs_op)) => {
                hop.trace_logdet_hessian_cross_operator(lhs_op.as_ref(), rhs_op.as_ref())
            }
            (Self::Dense(lhs_dense), Self::Dense(rhs_dense)) => {
                hop.trace_logdet_hessian_cross(lhs_dense, rhs_dense)
            }
            (Self::Dense(dense), Self::Operator(op)) | (Self::Operator(op), Self::Dense(dense)) => {
                hop.trace_logdet_hessian_cross_matrix_operator(dense, op.as_ref())
            }
        }
    }
}
/// A hyper-operator formed as the sum of an optional dense matrix and a list of
/// (possibly implicit) hyper-operators.
#[derive(Clone)]
pub struct CompositeHyperOperator {
    /// Optional dense summand; when present it is added to every product/trace.
    pub dense: Option<Array2<f64>>,
    /// Operator summands, each applied and accumulated in list order.
    pub operators: Vec<Arc<dyn HyperOperator>>,
    /// Dimension reported by `dim()`; also sizes the zero matrix `to_dense` starts from
    /// when `dense` is `None`.
    pub dim_hint: usize,
}
/// Sums the projected traces (`trace_projected_factor*`) of all `operators` against
/// `factor`, batching implicit operators that share their data.
///
/// Implicit operators are grouped when they share `implicit_deriv`, `x_design`, and
/// `w_diag` (compared by `Arc` pointer) and have the same `p`. Each group of size >= 2
/// forms `X F` once (from `cache` when provided) and evaluates every axis in one pass via
/// `trace_projected_factor_all_axes_with_xf`.
///
/// Singleton implicit groups and all non-implicit operators use the per-operator trace
/// (`_cached` variant when `cache` is `Some`). The accumulation order is: groups in order
/// of their lead index, then the non-implicit operators in list order.
pub(crate) fn composite_trace_implicit_batched(
    operators: &[Arc<dyn HyperOperator>],
    factor: &Array2<f64>,
    cache: Option<&ProjectedFactorCache>,
) -> f64 {
    let mut trace = 0.0;
    // Despite the name, each entry is a full group of operator indices (lead first).
    let mut group_starts: Vec<Vec<usize>> = Vec::new();
    // Marks implicit operators already assigned to a group; non-implicit ones stay false
    // and are picked up by the final loop.
    let mut handled = vec![false; operators.len()];
    for (i, op) in operators.iter().enumerate() {
        if handled[i] {
            continue;
        }
        let Some(impl_i) = as_implicit(op.as_ref()) else {
            continue;
        };
        let mut group = vec![i];
        handled[i] = true;
        for j in (i + 1)..operators.len() {
            if handled[j] {
                continue;
            }
            // Same derivative, design, weights (by pointer) and dimension => `X F` and
            // the unprojected factor can be shared across the group.
            if let Some(impl_j) = as_implicit(operators[j].as_ref())
                && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
                && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
                && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
                && impl_i.p == impl_j.p
            {
                group.push(j);
                handled[j] = true;
            }
        }
        group_starts.push(group);
    }
    for group in &group_starts {
        if group.len() >= 2 {
            // Every member was admitted by `as_implicit`, so these unwraps cannot fail.
            let lead = as_implicit(operators[group[0]].as_ref()).unwrap();
            let xf = match cache {
                Some(c) => lead.cached_xf(factor, c),
                None => Arc::new(lead.compute_xf(factor)),
            };
            let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
                .iter()
                .map(|&k| {
                    let op = as_implicit(operators[k].as_ref()).unwrap();
                    (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
                })
                .collect();
            let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
            trace += values.iter().sum::<f64>();
        } else {
            // Singleton group: no sharing to exploit, use the operator's own trace.
            let op = &operators[group[0]];
            trace += match cache {
                Some(c) => op.trace_projected_factor_cached(factor, c),
                None => op.trace_projected_factor(factor),
            };
        }
    }
    // Remaining (non-implicit) operators.
    for (i, op) in operators.iter().enumerate() {
        if handled[i] {
            continue;
        }
        trace += match cache {
            Some(c) => op.trace_projected_factor_cached(factor, c),
            None => op.trace_projected_factor(factor),
        };
    }
    trace
}
/// Returns each operator's cached projected trace against `factor`, indexed like
/// `operators`.
///
/// Implicit operators that share `implicit_deriv`, `x_design`, and `w_diag` (by `Arc`
/// pointer) and `p` are evaluated together via `trace_projected_factor_all_axes_with_xf`
/// with a single cached `X F`. Everything else (including singleton implicit groups) uses
/// `trace_projected_factor_cached`.
pub(crate) fn trace_projected_factors_batched(
    operators: &[Arc<dyn HyperOperator>],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let count = operators.len();
    let mut traces = vec![0.0; count];
    let mut done = vec![false; count];
    for lead_idx in 0..count {
        if done[lead_idx] {
            continue;
        }
        done[lead_idx] = true;
        let lead_op = &operators[lead_idx];
        let Some(lead) = as_implicit(lead_op.as_ref()) else {
            traces[lead_idx] = lead_op.trace_projected_factor_cached(factor, cache);
            continue;
        };

        // Gather later implicit operators whose shared data matches the lead.
        let mut members = vec![lead_idx];
        for other_idx in (lead_idx + 1)..count {
            if done[other_idx] {
                continue;
            }
            let Some(other) = as_implicit(operators[other_idx].as_ref()) else {
                continue;
            };
            let shares_data = Arc::ptr_eq(&lead.implicit_deriv, &other.implicit_deriv)
                && Arc::ptr_eq(&lead.x_design, &other.x_design)
                && Arc::ptr_eq(lead.w_diag.as_arc(), other.w_diag.as_arc())
                && lead.p == other.p;
            if shares_data {
                members.push(other_idx);
                done[other_idx] = true;
            }
        }

        if members.len() < 2 {
            traces[lead_idx] = lead_op.trace_projected_factor_cached(factor, cache);
            continue;
        }

        let xf = lead.cached_xf(factor, cache);
        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = members
            .iter()
            .map(|&member_idx| {
                let member = as_implicit(operators[member_idx].as_ref()).unwrap();
                (member.axis, &member.s_psi, member.c_x_psi_beta.as_deref())
            })
            .collect();
        let per_axis = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
        for (&member_idx, trace) in members.iter().zip(per_axis) {
            traces[member_idx] = trace;
        }
    }
    traces
}
/// Flattens `op` (recursing through composite and weighted wrappers) into leaf trace
/// terms.
///
/// Dense parts of composites are traced immediately and added, scaled by the accumulated
/// weight, into `dense_acc[out_idx]`. Every leaf operator is appended to `terms` as
/// `(out_idx, accumulated_weight, leaf)`. Zero-weight subtrees are skipped.
pub(crate) fn collect_projected_trace_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [f64],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }

    if let Some(composite) = as_composite(op) {
        if let Some(dense_part) = &composite.dense {
            dense_acc[out_idx] += weight * dense_trace_projected_factor(dense_part, factor);
        }
        for child in composite.operators.iter() {
            collect_projected_trace_terms(out_idx, weight, child.as_ref(), factor, dense_acc, terms);
        }
        return;
    }

    if let Some(weighted) = as_weighted(op) {
        for (child_weight, child) in weighted.terms.iter() {
            let scaled = weight * *child_weight;
            collect_projected_trace_terms(out_idx, scaled, child.as_ref(), factor, dense_acc, terms);
        }
        return;
    }

    terms.push((out_idx, weight, op));
}
/// Flattens `op` (recursing through composite and weighted wrappers) into leaf
/// projected-matrix terms.
///
/// Dense parts of composites are projected immediately and added, scaled by the
/// accumulated weight, into `dense_acc[out_idx]`. Every leaf operator is appended to
/// `terms` as `(out_idx, accumulated_weight, leaf)`. Zero-weight subtrees are skipped.
pub(crate) fn collect_projected_matrix_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [Array2<f64>],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }

    if let Some(composite) = as_composite(op) {
        if let Some(dense_part) = &composite.dense {
            dense_acc[out_idx].scaled_add(weight, &dense_projected_matrix(dense_part, factor));
        }
        for child in composite.operators.iter() {
            collect_projected_matrix_terms(out_idx, weight, child.as_ref(), factor, dense_acc, terms);
        }
        return;
    }

    if let Some(weighted) = as_weighted(op) {
        for (child_weight, child) in weighted.terms.iter() {
            let scaled = weight * *child_weight;
            collect_projected_matrix_terms(out_idx, scaled, child.as_ref(), factor, dense_acc, terms);
        }
        return;
    }

    terms.push((out_idx, weight, op));
}
/// Evaluates weighted leaf trace terms and accumulates them into `n_out` output slots.
///
/// Each term is `(out_idx, weight, op)`; its cached projected trace against `factor`,
/// multiplied by `weight`, is added to `out[out_idx]`.
///
/// Every implicit term is routed through `trace_projected_factor_all_axes_with_xf`, even
/// when its group has a single member. Terms are grouped when they share `implicit_deriv`,
/// `x_design`, and `w_diag` (by `Arc` pointer) and `p`, so each group reuses one cached
/// `X F`. Non-implicit terms are evaluated afterwards with `trace_projected_factor_cached`.
pub(crate) fn trace_projected_operator_terms_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let mut out = vec![0.0_f64; n_out];
    // Only implicit terms are ever marked; the rest are handled by the trailing loop.
    let mut handled = vec![false; terms.len()];
    for i in 0..terms.len() {
        if handled[i] {
            continue;
        }
        let Some(impl_i) = as_implicit(terms[i].2) else {
            continue;
        };
        let mut group = vec![i];
        handled[i] = true;
        for j in (i + 1)..terms.len() {
            if handled[j] {
                continue;
            }
            if let Some(impl_j) = as_implicit(terms[j].2)
                && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
                && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
                && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
                && impl_i.p == impl_j.p
            {
                group.push(j);
                handled[j] = true;
            }
        }
        // All group members passed `as_implicit`, so these unwraps cannot fail.
        let lead = as_implicit(terms[group[0]].2).unwrap();
        let xf = lead.cached_xf(factor, cache);
        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
            .iter()
            .map(|&term_idx| {
                let op = as_implicit(terms[term_idx].2).unwrap();
                (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
            })
            .collect();
        let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
        for (&term_idx, value) in group.iter().zip(values.iter()) {
            let (out_idx, weight, _) = terms[term_idx];
            out[out_idx] += weight * *value;
        }
    }
    // Non-implicit terms.
    for (i, (out_idx, weight, op)) in terms.iter().enumerate() {
        if handled[i] {
            continue;
        }
        out[*out_idx] += *weight * op.trace_projected_factor_cached(factor, cache);
    }
    out
}
/// Accumulates weighted projected matrices `weight * Fᵀ A F` (via
/// `projected_matrix_cached`) into `n_out` square `rank × rank` slots, where `rank` is
/// `factor.ncols()`.
pub(crate) fn projected_operator_terms_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    let rank = factor.ncols();
    let mut slots = vec![Array2::<f64>::zeros((rank, rank)); n_out];
    for &(slot, weight, op) in terms {
        let projected = op.projected_matrix_cached(factor, cache);
        slots[slot].scaled_add(weight, &projected);
    }
    slots
}
/// Alias of [`projected_operator_terms_batched`]: accumulates weighted projected matrices
/// of the given leaf terms into `n_out` output slots.
pub(crate) fn project_hyper_operators_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    let (slots, leaves) = (n_out, terms);
    projected_operator_terms_batched(slots, leaves, factor, cache)
}
/// Projected trace of every drift against `factor`, one value per drift.
///
/// Dense drifts are traced directly. Operator drifts are flattened into weighted leaf
/// terms (their composite dense parts are traced during flattening), and all leaf terms
/// are then evaluated together by `trace_projected_operator_terms_batched`. The batched
/// results are added on top of the direct contributions.
pub(crate) fn trace_logdet_drifts_projected_factor_batched(
    drifts: &[DriftDerivResult],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let n_drifts = drifts.len();
    let mut traces = vec![0.0_f64; n_drifts];
    let mut deferred: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
    for (slot, drift) in drifts.iter().enumerate() {
        match drift {
            DriftDerivResult::Operator(op) => {
                collect_projected_trace_terms(slot, 1.0, op.as_ref(), factor, &mut traces, &mut deferred);
            }
            DriftDerivResult::Dense(matrix) => {
                traces[slot] += dense_trace_projected_factor(matrix, factor);
            }
        }
    }
    let operator_traces = trace_projected_operator_terms_batched(n_drifts, &deferred, factor, cache);
    traces
        .iter_mut()
        .zip(operator_traces)
        .for_each(|(total, extra)| *total += extra);
    traces
}
/// Batched drift traces for a dense spectral operator, using its `g_factor` and its
/// projected-factor cache.
pub(crate) fn dense_spectral_trace_logdet_drifts_batched(
    ds: &DenseSpectralOperator,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    let factor = &ds.g_factor;
    let cache = &ds.projected_factor_cache;
    trace_logdet_drifts_projected_factor_batched(drifts, factor, cache)
}
/// Builds the trace factor `U_s · V · diag(sqrt(max(λ, 0)))` from the eigendecomposition
/// `V diag(λ) Vᵀ` of `kernel.h_proj_inverse`.
///
/// Negative eigenvalues are clamped to zero before taking the square root, so the factor
/// is always real.
///
/// # Panics
/// Panics if the eigendecomposition of `h_proj_inverse` fails.
pub(crate) fn penalty_subspace_trace_factor(kernel: &PenaltySubspaceTrace) -> Array2<f64> {
    // The eigenvector matrix is already owned, so it is scaled in place instead of
    // cloning it first.
    let (evals, mut root) = kernel
        .h_proj_inverse
        .eigh(faer::Side::Lower)
        .expect("PenaltySubspaceTrace kernel factor eigendecomposition failed");
    for (col, &lambda) in evals.iter().enumerate() {
        let scale = lambda.max(0.0).sqrt();
        root.column_mut(col).mapv_inplace(|entry| entry * scale);
    }
    crate::faer_ndarray::fast_ab(&kernel.u_s, &root)
}
/// Batched drift traces against the penalty-subspace trace factor, using a fresh
/// (call-local) projected-factor cache.
pub(crate) fn penalty_subspace_trace_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    trace_logdet_drifts_projected_factor_batched(
        drifts,
        &penalty_subspace_trace_factor(kernel),
        &ProjectedFactorCache::default(),
    )
}
/// Reduces each drift through `kernel`, choosing `reduce` for dense drifts and
/// `reduce_operator` for operator drifts. Output order matches `drifts`.
pub(crate) fn penalty_subspace_reduce_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<Array2<f64>> {
    let mut reduced = Vec::with_capacity(drifts.len());
    for drift in drifts {
        let block = match drift {
            DriftDerivResult::Operator(op) => kernel.reduce_operator(op.as_ref()),
            DriftDerivResult::Dense(matrix) => kernel.reduce(matrix),
        };
        reduced.push(block);
    }
    reduced
}
/// Per-operator projected traces against a dense spectral operator's `g_factor`.
///
/// Returns an empty vector for no operators. When `log` is enabled at `Info`, the call
/// is timed and reported through `dense_spectral_stage_log`.
pub(crate) fn dense_spectral_trace_logdet_operators_batched(
    ds: &DenseSpectralOperator,
    operators: &[Arc<dyn HyperOperator>],
) -> Vec<f64> {
    if operators.is_empty() {
        return Vec::new();
    }
    let timer = log::log_enabled!(log::Level::Info).then(std::time::Instant::now);
    let traces =
        trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache);
    if let Some(started) = timer {
        let implicit_ops = operators.iter().filter(|op| op.is_implicit()).count();
        dense_spectral_stage_log(
            &format!(
                "DenseSpectralOperator::trace_logdet_operators_batched dim={} rank={} ops={} implicit_ops={}",
                ds.n_dim,
                ds.g_factor.ncols(),
                operators.len(),
                implicit_ops,
            ),
            started.elapsed().as_secs_f64(),
        );
    }
    traces
}
// `CompositeHyperOperator` acts as `dense + Σ operators[i]`. Most methods short-circuit
// to the single inner operator when there is no dense part and exactly one operator.
impl HyperOperator for CompositeHyperOperator {
    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
        self
    }
    /// Reports `dim_hint`.
    fn dim(&self) -> usize {
        self.dim_hint
    }
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, out.view_mut());
        out
    }
    /// Writes `(dense + Σ ops) v` into `out`, overwriting its previous contents.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        // Single operator, no dense part: delegate directly.
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_vec_into(v, out);
            return;
        }
        out.fill(0.0);
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_into(dense, v, out.view_mut());
        }
        for op in &self.operators {
            op.scaled_add_mul_vec(v, 1.0, out.view_mut());
        }
    }
    /// Writes columns `start..start + out.ncols()` of the composite operator into `out`.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_basis_columns_into(start, out);
            return;
        }
        out.fill(0.0);
        let cols = out.ncols();
        let end = start + cols;
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.slice(ndarray::s![.., start..end]);
        }
        // NOTE(review): `work` is reused without re-zeroing between operators, which
        // assumes each inner `mul_basis_columns_into` overwrites every entry (as this
        // impl does via `fill(0.0)`) — confirm for all operator types.
        let mut work = Array2::<f64>::zeros((out.nrows(), cols));
        for op in &self.operators {
            op.mul_basis_columns_into(start, work.view_mut());
            out += &work;
        }
    }
    /// Adds `scale * (dense + Σ ops) v` into `out`; a zero `scale` is a no-op.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].scaled_add_mul_vec(v, scale, out);
            return;
        }
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        for op in &self.operators {
            op.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }
    /// Returns `(dense + Σ ops) · factor`.
    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].mul_mat(factor);
        }
        let p = factor.nrows();
        let k = factor.ncols();
        let mut out = Array2::<f64>::zeros((p, k));
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.dot(factor);
        }
        for op in &self.operators {
            out += &op.mul_mat(factor);
        }
        out
    }
    /// Projected trace of the composite: the dense part is computed as
    /// `Σ F ⊙ (dense · F)`, and the operator parts are batched by
    /// `composite_trace_implicit_batched` without a cache.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor(factor);
        }
        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, None);
        trace
    }
    /// Same as `trace_projected_factor`, but the operator parts reuse `cache`.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor_cached(factor, cache);
        }
        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
        trace
    }
    /// Returns `Fᵀ (dense + Σ ops) F` as a `rank × rank` matrix.
    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].projected_matrix(factor);
        }
        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = self.dense.as_ref() {
            let mf = crate::faer_ndarray::fast_ab(dense, factor);
            projected += &crate::faer_ndarray::fast_atb(factor, &mf);
        }
        for op in &self.operators {
            projected += &op.projected_matrix(factor);
        }
        projected
    }
    /// Same as `projected_matrix`, but inner operators use their cached variant.
    fn projected_matrix_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].projected_matrix_cached(factor, cache);
        }
        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = self.dense.as_ref() {
            let mf = crate::faer_ndarray::fast_ab(dense, factor);
            projected += &crate::faer_ndarray::fast_atb(factor, &mf);
        }
        for op in &self.operators {
            projected += &op.projected_matrix_cached(factor, cache);
        }
        projected
    }
    /// Sum of the dense and per-operator bilinear forms in `(v, u)`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            total += dense_bilinear(dense, v.view(), u.view());
        }
        for op in &self.operators {
            total += op.bilinear(v, u);
        }
        total
    }
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            total += dense_bilinear(dense, v, u);
        }
        for op in &self.operators {
            total += op.bilinear_view(v, u);
        }
        total
    }
    /// Materializes the operator: the dense part (or a `dim_hint × dim_hint` zero matrix)
    /// plus each inner operator's `to_dense`.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = self
            .dense
            .clone()
            .unwrap_or_else(|| Array2::<f64>::zeros((self.dim_hint, self.dim_hint)));
        for op in &self.operators {
            out += &op.to_dense();
        }
        out
    }
    /// Implicit if any inner operator is implicit.
    fn is_implicit(&self) -> bool {
        self.operators.iter().any(|op| op.is_implicit())
    }
}
mod implicit_matvec_scratch {
    //! Per-thread reusable work buffers for implicit hyper-operator mat-vecs.
    use std::cell::RefCell;

    /// Growable scratch vectors. They keep their allocation between calls; users clear
    /// and resize them before use.
    pub(super) struct Scratch {
        pub x_v: Vec<f64>,
        pub n_work: Vec<f64>,
        pub p_work: Vec<f64>,
    }

    impl Scratch {
        /// Empty buffers; `const` so the thread-local can be initialized in a const block.
        pub(crate) const fn new() -> Self {
            Scratch {
                x_v: Vec::new(),
                n_work: Vec::new(),
                p_work: Vec::new(),
            }
        }
    }

    thread_local! {
        static SCRATCH: RefCell<Scratch> = const { RefCell::new(Scratch::new()) };
    }

    /// Runs `f` with exclusive access to the current thread's scratch buffers.
    ///
    /// # Panics
    /// Panics if called re-entrantly from inside `f` (the buffers are already borrowed).
    pub(super) fn with<R>(f: impl FnOnce(&mut Scratch) -> R) -> R {
        SCRATCH.with_borrow_mut(f)
    }
}
/// Hyper-parameter derivative operator whose design-derivative part is applied
/// implicitly (never materialized).
///
/// As assembled in `matvec_with_shared_xz_into`, the operator applied to `z` is
/// `Dᵀ W X z + Xᵀ W D z + S_psi z + Xᵀ diag(c) X z`, where `D` is `implicit_deriv` along
/// `axis`, `X` is `x_design`, `W` is `w_diag`, and the last term is present only when
/// `c_x_psi_beta` is `Some`.
pub struct ImplicitHyperOperator {
    /// Implicit design derivative `D`, applied via `forward_mul` / `transpose_mul`.
    /// Shared by `Arc` so batching code can group operators by pointer identity.
    pub implicit_deriv: std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>,
    /// Derivative axis of `implicit_deriv` this operator represents.
    pub axis: usize,
    /// Design matrix `X` (shared; also used as the `X F` cache key by pointer).
    pub(crate) x_design: std::sync::Arc<DesignMatrix>,
    /// Per-observation weights `W`; its length is the observation count.
    pub(crate) w_diag: crate::matrix::SignedWeightsArc,
    /// Dense penalty-derivative summand `S_psi` (`p × p`).
    pub s_psi: Array2<f64>,
    /// Coefficient dimension (rows/cols of the operator).
    pub(crate) p: usize,
    /// Optional per-observation coefficients `c` for the `Xᵀ diag(c) X` correction.
    pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
}
impl HyperOperator for ImplicitHyperOperator {
fn dim(&self) -> usize {
self.p
}
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v.view(), out.view_mut());
out
}
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v, out.view_mut());
out
}
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
assert_eq!(v.len(), self.p);
let n_obs = self.w_diag.len();
implicit_matvec_scratch::with(|s| {
s.x_v.clear();
s.x_v.resize(n_obs, 0.0);
s.n_work.clear();
s.n_work.resize(n_obs, 0.0);
s.p_work.clear();
s.p_work.resize(self.p, 0.0);
let mut x_v_view = ndarray::ArrayViewMut1::from(s.x_v.as_mut_slice());
let n_work_view = ndarray::ArrayViewMut1::from(s.n_work.as_mut_slice());
let p_work_view = ndarray::ArrayViewMut1::from(s.p_work.as_mut_slice());
design_matrix_apply_view_into(&self.x_design, v, x_v_view.view_mut());
self.matvec_with_shared_xz_into(x_v_view.view(), v, out, n_work_view, p_work_view);
});
}
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
assert!(start + cols <= self.p);
let n_obs = self.w_diag.len();
let mut basis = Array1::<f64>::zeros(self.p);
let mut x_col = Array1::<f64>::zeros(n_obs);
let mut dx_col = Array1::<f64>::zeros(n_obs);
let mut weighted = Array1::<f64>::zeros(n_obs);
let mut term = Array1::<f64>::zeros(self.p);
for local_col in 0..cols {
let global_col = start + local_col;
let mut out_col = out.column_mut(local_col);
out_col.assign(&self.s_psi.column(global_col));
design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(x_col.view())
.par_for_each(|dst, &w, &x| *dst = w * x);
term.assign(
&self
.implicit_deriv
.transpose_mul(self.axis, &weighted.view())
.expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
);
out_col += &term;
basis[global_col] = 1.0;
dx_col.assign(
&self
.implicit_deriv
.forward_mul(self.axis, &basis.view())
.expect("radial scalar evaluation failed during implicit hyper forward_mul"),
);
basis[global_col] = 0.0;
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(dx_col.view())
.par_for_each(|dst, &w, &dx| *dst = w * dx);
design_matrix_transpose_apply_view_into(
&self.x_design,
weighted.view(),
term.view_mut(),
);
out_col += &term;
self.accumulate_c_correction_xt_into(
x_col.view(),
weighted.view_mut(),
term.view_mut(),
out_col,
);
}
}
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
self.bilinear_view(v.view(), u.view())
}
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
assert_eq!(v.len(), self.p);
assert_eq!(u.len(), self.p);
let x_v = design_matrix_apply_view(&self.x_design, v);
let x_u = design_matrix_apply_view(&self.x_design, u);
let dx_v = self
.implicit_deriv
.forward_mul(self.axis, &v)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let dx_u = self
.implicit_deriv
.forward_mul(self.axis, &u)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let w = &*self.w_diag;
let mut design = 0.0;
for i in 0..w.len() {
design += dx_v[i] * w[i] * x_u[i];
design += dx_u[i] * w[i] * x_v[i];
}
design += self.c_correction_bilinear(&x_v, &x_u);
let penalty = dense_bilinear(&self.s_psi, v, u);
design + penalty
}
fn is_implicit(&self) -> bool {
true
}
fn as_any(&self) -> &(dyn std::any::Any + 'static) {
self
}
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
assert_eq!(factor.nrows(), self.p);
let n_obs = self.w_diag.len();
let rank = factor.ncols();
if rank == 0 || n_obs == 0 {
return 0.0;
}
let xf = self.compute_xf(factor);
self.trace_projected_factor_with_xf(factor, xf.view())
}
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> f64 {
assert_eq!(factor.nrows(), self.p);
let n_obs = self.w_diag.len();
let rank = factor.ncols();
if rank == 0 || n_obs == 0 {
return 0.0;
}
let xf = self.cached_xf(factor, cache);
self.trace_projected_factor_with_xf(factor, xf.view())
}
}
/// Number of rows to process per chunk so that a chunk of `cols` `f64` columns stays
/// near an 8 MiB budget.
///
/// The result is at least 512 rows (unless `n_rows` is smaller) and never exceeds
/// `n_rows`, so it is `0` when `n_rows == 0`. A `cols` of 0 is treated as 1.
pub(crate) fn byte_balanced_row_chunk(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * std::cmp::max(cols, 1);
    let rows_within_budget = TARGET_BYTES / row_bytes;
    std::cmp::min(std::cmp::max(rows_within_budget, MIN_CHUNK_ROWS), n_rows)
}
impl ImplicitHyperOperator {
    /// Forms `X F` (`n_obs × rank`) by multiplying row chunks of `x_design` with `factor`.
    ///
    /// Only `byte_balanced_row_chunk(p + rank, n_obs)` rows are requested from the design
    /// at a time. Row-chunk failures are routed to `reml_contract_panic`.
    pub(crate) fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_obs = self.w_diag.len();
        let rank = factor.ncols();
        let mut xf = Array2::<f64>::zeros((n_obs, rank));
        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let rows = self
                .x_design
                .try_row_chunk(start..end)
                .unwrap_or_else(|err| {
                    reml_contract_panic(format!(
                        "ImplicitHyperOperator::compute_xf row chunk failed: {err}"
                    ))
                });
            let block = crate::faer_ndarray::fast_ab(&rows, factor);
            xf.slice_mut(ndarray::s![start..end, ..]).assign(&block);
            start = end;
        }
        xf
    }
    /// Returns `X F`, memoized in `cache`. The key combines the design's `Arc` pointer
    /// address with `factor`, so operators sharing `x_design` share the entry.
    pub(crate) fn cached_xf(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> Arc<Array2<f64>> {
        let design_id = Arc::as_ptr(&self.x_design) as usize;
        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
        cache.get_or_insert_with(key, || self.compute_xf(factor))
    }
    /// Projected trace given a precomputed `xf = X F`:
    /// `2 Σ_i w_i ⟨(D F)_i, (X F)_i⟩ + Σ_i c_i ‖(X F)_i‖² + Σ F ⊙ (S_psi F)`.
    ///
    /// `D F` is formed per row chunk as `row_chunk_first_raw(axis, chunk) ·
    /// unproject_matrix(F)`, and the `c` term is present only when `c_x_psi_beta` is set.
    /// Panics if `xf` is not `n_obs × rank`.
    pub(crate) fn trace_projected_factor_with_xf(
        &self,
        factor: &Array2<f64>,
        xf: ArrayView2<'_, f64>,
    ) -> f64 {
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        assert_eq!(xf.dim(), (n_obs, rank));
        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
        let w = self.w_diag.as_ref();
        let c_opt = self.c_x_psi_beta.as_ref().map(|arc| arc.as_ref());
        let mut design_total = 0.0_f64;
        let mut correction_total = 0.0_f64;
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;
            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
            let kd_chunk = self
                .implicit_deriv
                .row_chunk_first_raw(self.axis, start..end)
                .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
            // Rows `start..end` of `D F`.
            let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
            for i_local in 0..chunk_n {
                let i = start + i_local;
                let w_i = w[i];
                let dxf_row = dxf_chunk.row(i_local);
                let xf_row = xf_chunk.row(i_local);
                for k in 0..rank {
                    design_total += dxf_row[k] * w_i * xf_row[k];
                }
                if let Some(c) = c_opt {
                    let c_i = c[i];
                    for k in 0..rank {
                        let v = xf_row[k];
                        correction_total += c_i * v * v;
                    }
                }
            }
            start = end;
        }
        let s_f = self.s_psi.dot(factor);
        let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
        // The two cross terms DᵀWX and XᵀWD have equal traces, hence the factor 2.
        2.0 * design_total + correction_total + penalty
    }
    /// Multi-axis version of `trace_projected_factor_with_xf` that reuses one `xf` and
    /// one `unproject_matrix(F)`.
    ///
    /// Each entry of `axes` is `(axis, s_psi, c)` for one operator. The result has one
    /// trace per entry, in the same order. `self.implicit_deriv` and `self.w_diag` are
    /// used for every axis, so callers must only pass axes from operators that share them
    /// (the batching callers group by `Arc` pointer).
    pub(crate) fn trace_projected_factor_all_axes_with_xf(
        &self,
        factor: &Array2<f64>,
        xf: ArrayView2<'_, f64>,
        axes: &[(usize, &Array2<f64>, Option<&Array1<f64>>)],
    ) -> Vec<f64> {
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        assert_eq!(xf.dim(), (n_obs, rank));
        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs.max(1));
        let w = self.w_diag.as_ref();
        let mut design_totals = vec![0.0_f64; axes.len()];
        let mut correction_totals = vec![0.0_f64; axes.len()];
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;
            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
            for (axis_idx, (axis, _s_psi, c_opt_axis)) in axes.iter().enumerate() {
                let kd_chunk = self
                    .implicit_deriv
                    .row_chunk_first_raw(*axis, start..end)
                    .expect(
                        "radial scalar evaluation failed during \
                         trace_projected_factor_all_axes_with_xf",
                    );
                let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
                for i_local in 0..chunk_n {
                    let i = start + i_local;
                    let w_i = w[i];
                    let dxf_row = dxf_chunk.row(i_local);
                    let xf_row = xf_chunk.row(i_local);
                    for k in 0..rank {
                        design_totals[axis_idx] += dxf_row[k] * w_i * xf_row[k];
                    }
                    if let Some(c) = c_opt_axis {
                        let c_i = c[i];
                        for k in 0..rank {
                            let v = xf_row[k];
                            correction_totals[axis_idx] += c_i * v * v;
                        }
                    }
                }
            }
            start = end;
        }
        // The penalty term is added even when n_obs == 0.
        axes.iter()
            .enumerate()
            .map(|(idx, (_axis, s_psi, _c_opt_axis))| {
                let s_f = s_psi.dot(factor);
                let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
                2.0 * design_totals[idx] + correction_totals[idx] + penalty
            })
            .collect()
    }
    /// Adds `Xᵀ (c ⊙ x_col)` into `out_col`; a no-op when `c_x_psi_beta` is `None`.
    /// `n_work` (length `n_obs`) and `p_work` (length `p`) are overwritten as scratch.
    pub(crate) fn accumulate_c_correction_xt_into(
        &self,
        x_col: ArrayView1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
        mut out_col: ArrayViewMut1<'_, f64>,
    ) {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return;
        };
        let c = c_x_psi_beta.as_ref();
        assert_eq!(x_col.len(), c.len());
        assert_eq!(n_work.len(), c.len());
        assert_eq!(p_work.len(), self.p);
        for i in 0..c.len() {
            n_work[i] = c[i] * x_col[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out_col += &p_work;
    }
    /// `Σ_i x_v[i] c[i] x_u[i]`, or 0.0 when `c_x_psi_beta` is `None`.
    /// Inputs are zipped, so the sum stops at the shortest length.
    pub(crate) fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return 0.0;
        };
        x_v.iter()
            .zip(x_u.iter())
            .zip(c_x_psi_beta.iter())
            .map(|((&xv, &xu), &c)| xv * c * xu)
            .sum()
    }
    /// Bilinear form `zᵀ A u` with the design products supplied by the caller.
    ///
    /// NOTE(review): appears to expect `x_vec = X z` and `y_vec = X u` (mirroring the
    /// expansion in `bilinear_view`) — confirm at call sites. Iterates over
    /// `0..x_vec.len()` and indexes the other per-observation arrays with it.
    pub fn bilinear_with_shared_x(
        &self,
        x_vec: &Array1<f64>,
        y_vec: &Array1<f64>,
        z: &Array1<f64>,
        u: &Array1<f64>,
    ) -> f64 {
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let dx_u = self
            .implicit_deriv
            .forward_mul(self.axis, &u.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let mut design = 0.0f64;
        let w = &*self.w_diag;
        for i in 0..x_vec.len() {
            let wi = w[i];
            design += dx_z[i] * wi * y_vec[i];
            design += dx_u[i] * wi * x_vec[i];
        }
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..x_vec.len() {
                design += y_vec[i] * c[i] * x_vec[i];
            }
        }
        let penalty = dense_bilinear(&self.s_psi, z.view(), u.view());
        design + penalty
    }
    /// Writes `A z` into `out` given a precomputed `x_vec = X z` (as `mul_vec_into`
    /// passes it):
    /// `Dᵀ (W X z) + Xᵀ (W D z) + S_psi z (+ Xᵀ (c ⊙ X z))`.
    ///
    /// `n_work` (length `n_obs`) and `p_work` (length `p`) are overwritten as scratch.
    pub fn matvec_with_shared_xz_into(
        &self,
        x_vec: ArrayView1<'_, f64>,
        z: ArrayView1<'_, f64>,
        mut out: ArrayViewMut1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
    ) {
        assert_eq!(z.len(), self.p);
        assert_eq!(out.len(), self.p);
        assert_eq!(n_work.len(), self.w_diag.len());
        assert_eq!(p_work.len(), self.p);
        let w = &*self.w_diag;
        for i in 0..w.len() {
            n_work[i] = w[i] * x_vec[i];
        }
        let term1 = self
            .implicit_deriv
            .transpose_mul(self.axis, &n_work.view())
            .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
        out.assign(&term1);
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        for i in 0..w.len() {
            n_work[i] = w[i] * dx_z[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out += &p_work;
        dense_matvec_into(&self.s_psi, z, p_work.view_mut());
        out += &p_work;
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..w.len() {
                n_work[i] = c[i] * x_vec[i];
            }
            design_matrix_transpose_apply_view_into(
                &self.x_design,
                n_work.view(),
                p_work.view_mut(),
            );
            out += &p_work;
        }
    }
}
/// Explicit directional hyper-operator. As applied in `mul_vec`, it computes
/// `X_τᵀ W X v + Xᵀ W X_τ v + S_τ v (+ Xᵀ (c ⊙ X v)) (− H_φτ v)`.
pub struct SparseDirectionalHyperOperator {
    /// Design derivative `X_τ`, applied via `forward_mul_original` / `transpose_mul_original`.
    pub(crate) x_tau: super::super::HyperDesignDerivative,
    /// Design matrix `X`.
    pub(crate) x_design: DesignMatrix,
    /// Per-observation weights `W`.
    pub(crate) w_diag: crate::matrix::SignedWeightsArc,
    /// Dense penalty-derivative summand `S_τ`.
    pub(crate) s_tau: Array2<f64>,
    /// Optional per-observation coefficients `c` for the `Xᵀ (c ⊙ X v)` correction.
    pub(crate) c_x_tau_beta: Option<Array1<f64>>,
    /// Optional dense matrix subtracted from the product (`H_φτ` above).
    pub(crate) firth_hphi_tau_partial: Option<Array2<f64>>,
    /// Coefficient dimension.
    pub(crate) p: usize,
}
impl HyperOperator for SparseDirectionalHyperOperator {
    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
        self
    }

    fn is_implicit(&self) -> bool {
        false
    }

    fn dim(&self) -> usize {
        self.p
    }

    /// `X_τᵀ W X v + Xᵀ W X_τ v + S_τ v`, plus `Xᵀ (c ⊙ X v)` when `c_x_tau_beta` is set,
    /// minus `H_φτ v` when `firth_hphi_tau_partial` is set.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let design_v = self.x_design.matrixvectormultiply(v);
        let tau_transpose_term = self
            .x_tau
            .transpose_mul_original(&(&*self.w_diag * &design_v))
            .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");
        let tau_v = self
            .x_tau
            .forward_mul_original(v)
            .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
        let design_transpose_term = self
            .x_design
            .transpose_vector_multiply(&(&*self.w_diag * &tau_v));
        let penalty_term = self.s_tau.dot(v);
        let mut product = tau_transpose_term + design_transpose_term + penalty_term;

        if let Some(c) = &self.c_x_tau_beta {
            product += &self.x_design.transpose_vector_multiply(&(c * &design_v));
        }
        if let Some(hphi) = &self.firth_hphi_tau_partial {
            product -= &hphi.dot(v);
        }
        product
    }
}
/// Operator `v ↦ Xᵀ (neg_c_xv ⊙ X v)`, as applied in its `mul_vec`.
pub struct GlmCurvatureCorrectionOperator {
    /// Design matrix `X`.
    pub(crate) x_design: DesignMatrix,
    /// Per-observation multipliers applied between `X` and `Xᵀ`.
    pub(crate) neg_c_xv: Array1<f64>,
    /// Coefficient dimension.
    pub(crate) p: usize,
}
impl HyperOperator for GlmCurvatureCorrectionOperator {
    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
        self
    }

    fn is_implicit(&self) -> bool {
        false
    }

    fn dim(&self) -> usize {
        self.p
    }

    /// Applies `Xᵀ (neg_c_xv ⊙ X v)`.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let design_v = self.x_design.matrixvectormultiply(v);
        self.x_design
            .transpose_vector_multiply(&(&self.neg_c_xv * &design_v))
    }
}