use super::*;
/// Bundle of per-coordinate hyperparameter quantities.
///
/// NOTE(review): the numerical meaning of each field is not established by the
/// code visible here; the field notes below describe types only and flag
/// name-based guesses explicitly.
#[derive(Clone)]
pub struct HyperCoord {
/// Scalar term for this coordinate (semantics not shown here — TODO confirm).
pub a: f64,
/// Vector term for this coordinate (semantics not shown here — TODO confirm).
pub g: Array1<f64>,
/// Drift contribution, stored as any mix of dense, block-local, and operator parts.
pub drift: HyperCoordDrift,
/// Scalar term; presumably a log-determinant-of-penalty derivative — verify against callers.
pub ld_s: f64,
/// Flag; by its name it marks a dependence on the coefficient vector — TODO confirm.
pub b_depends_on_beta: bool,
/// Flag; by its name it marks coordinates that behave like penalty parameters — TODO confirm.
pub is_penalty_like: bool,
/// Optional extra vector term (name suggests a Firth-correction piece — TODO confirm).
pub firth_g: Option<Array1<f64>>,
/// Optional fixed vector (semantics not shown here — TODO confirm).
pub tk_eta_fixed: Option<Array1<f64>>,
/// Optional fixed matrix (semantics not shown here — TODO confirm).
pub tk_x_fixed: Option<Array2<f64>>,
}
/// Bundle of quantities attached to a pair of hyperparameter coordinates.
///
/// NOTE(review): the numerical meaning of each field is not established by the
/// code visible here; confirm against the code that fills this struct.
pub struct HyperCoordPair {
/// Scalar term for the pair (semantics not shown here — TODO confirm).
pub a: f64,
/// Vector term for the pair (semantics not shown here — TODO confirm).
pub g: Array1<f64>,
/// Dense matrix term for the pair.
pub b_mat: Array2<f64>,
/// Optional operator-form term; presumably an alternative or supplement to `b_mat` — TODO confirm.
pub b_operator: Option<Box<dyn HyperOperator>>,
/// Scalar term (semantics not shown here — TODO confirm).
pub ld_s: f64,
}
impl HyperCoordPair {
    /// Builds an all-zero pair: zero scalars, a length-0 `g`, a `0 x 0`
    /// `b_mat`, and no operator.
    pub fn zero() -> Self {
        let g = Array1::zeros(0);
        let b_mat = Array2::zeros((0, 0));
        Self {
            a: 0.0,
            g,
            b_mat,
            b_operator: None,
            ld_s: 0.0,
        }
    }
}
/// A drift derivative held either as an explicit matrix or as a shared operator.
#[derive(Clone)]
pub enum DriftDerivResult {
/// Explicit dense matrix.
Dense(Array2<f64>),
/// Shared, reference-counted operator that is applied rather than stored as a matrix.
Operator(Arc<dyn HyperOperator>),
}
impl std::fmt::Debug for DriftDerivResult {
    /// Prints only a summary: `Dense(<rows>x<cols>)` for matrices and
    /// `Operator("<hyper-operator>")` for operators, never the contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Operator(_) => {
                let mut tuple = f.debug_tuple("Operator");
                tuple.field(&"<hyper-operator>");
                tuple.finish()
            }
            Self::Dense(matrix) => {
                let (rows, cols) = matrix.dim();
                let mut tuple = f.debug_tuple("Dense");
                tuple.field(&format_args!("{}x{}", rows, cols));
                tuple.finish()
            }
        }
    }
}
impl DriftDerivResult {
    /// Converts into a shared operator, wrapping a dense matrix in a
    /// [`DenseMatrixHyperOperator`] when necessary.
    pub fn into_operator(self) -> Arc<dyn HyperOperator> {
        match self {
            Self::Operator(operator) => operator,
            Self::Dense(matrix) => {
                let wrapped: Arc<dyn HyperOperator> = Arc::new(DenseMatrixHyperOperator { matrix });
                wrapped
            }
        }
    }

    /// Forwards to the matching `hop` trace routine for the stored form
    /// (`trace_logdet_gradient` for dense, `trace_logdet_operator` otherwise).
    pub fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Operator(operator) => hop.trace_logdet_operator(operator.as_ref()),
            Self::Dense(matrix) => hop.trace_logdet_gradient(matrix),
        }
    }

    /// Applies the drift to `v` and returns the product as a new vector.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        match self {
            Self::Operator(operator) => operator.mul_vec(v),
            Self::Dense(matrix) => matrix.dot(v),
        }
    }

    /// Dispatches the cross trace of `self` and `rhs` to the `hop` routine
    /// matching the dense/operator combination.
    ///
    /// For a mixed pair the dense matrix is always passed first, whichever
    /// side it came from.
    pub fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
        match self {
            Self::Dense(left) => match rhs {
                Self::Dense(right) => hop.trace_logdet_hessian_cross(left, right),
                Self::Operator(right) => {
                    hop.trace_logdet_hessian_cross_matrix_operator(left, right.as_ref())
                }
            },
            Self::Operator(left) => match rhs {
                Self::Dense(right) => {
                    hop.trace_logdet_hessian_cross_matrix_operator(right, left.as_ref())
                }
                Self::Operator(right) => {
                    hop.trace_logdet_hessian_cross_operator(left.as_ref(), right.as_ref())
                }
            },
        }
    }
}
/// Boxed, thread-safe callback that takes a `usize` and a vector and may
/// return a [`DriftDerivResult`].
///
/// NOTE(review): presumably the `usize` selects a coordinate and the vector is
/// a direction; confirm against the code that builds these callbacks.
pub type FixedDriftDerivFn =
Box<dyn Fn(usize, &Array1<f64>) -> Option<DriftDerivResult> + Send + Sync>;
/// Result bundle returned by a [`ContractedPsiSecondOrderFn`] callback.
///
/// NOTE(review): how the arrays are indexed (presumably one entry, row, or
/// element per coordinate) is not shown here — confirm against the producer.
pub struct ContractedPsiSecondOrder {
/// Vector of objective-related terms.
pub objective: Array1<f64>,
/// Matrix of score-related terms.
pub score: Array2<f64>,
/// Hessian-related drift terms, each in dense or operator form.
pub hessian: Vec<DriftDerivResult>,
/// Vector of scalar terms (semantics not shown here — TODO confirm).
pub ld_s: Array1<f64>,
}
/// Shared, thread-safe callback mapping a slice of `f64` to an optional
/// [`ContractedPsiSecondOrder`], with failures reported as a `String`.
///
/// NOTE(review): presumably the slice holds contraction weights or a
/// direction; confirm against callers.
pub type ContractedPsiSecondOrderFn =
Arc<dyn Fn(&[f64]) -> Result<Option<ContractedPsiSecondOrder>, String> + Send + Sync>;
/// Abstract square linear operator acting on `f64` vectors.
///
/// Implementors must supply `dim`, `mul_vec`, and `is_implicit`. Every other
/// method has a default built on those and can be overridden with a faster or
/// allocation-free path.
pub trait HyperOperator: Send + Sync {
/// Dimension `p` of the operator; `to_dense` materializes a `p x p` matrix.
fn dim(&self) -> usize;
/// Returns the operator applied to `v` as a new vector.
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64>;
/// View-based variant of `mul_vec`. The default copies `v` into an owned
/// array and delegates to `mul_vec`.
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
self.mul_vec(&v.to_owned())
}
/// Overwrites `out` with the operator applied to `v`. The default computes
/// the product through `mul_vec_view` and assigns it.
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
out.assign(&self.mul_vec_view(v));
}
/// Applies the operator to every column of `factor`; the result has the
/// same shape as `factor`.
fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let p = factor.nrows();
let k = factor.ncols();
let mut out = Array2::<f64>::zeros((p, k));
// Already running on a rayon worker thread: process the columns serially
// and write straight into `out` instead of starting a nested parallel loop.
if rayon::current_thread_index().is_some() {
for col in 0..k {
let bv = out.column_mut(col);
self.mul_vec_into(factor.column(col), bv);
}
return out;
}
// Otherwise compute each column in parallel into its own buffer, then
// copy the buffers into `out` in column order.
let cols: Vec<Array1<f64>> = (0..k)
.into_par_iter()
.map(|col| {
let mut bv = Array1::<f64>::zeros(p);
self.mul_vec_into(factor.column(col), bv.view_mut());
bv
})
.collect();
for (col, bv) in cols.into_iter().enumerate() {
out.column_mut(col).assign(&bv);
}
out
}
/// Returns the sum of the elementwise product of `factor` and the operator
/// applied to `factor`, i.e. the trace of `factorᵀ · (op · factor)`.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
let op_factor = self.mul_mat(factor);
factor
.iter()
.zip(op_factor.iter())
.map(|(&f, &bf)| f * bf)
.sum()
}
/// Cache-aware variant of `trace_projected_factor`. The default does not
/// consult the cache: it only touches `factor_cache` in an assertion and
/// then delegates to the uncached method.
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> f64 {
assert!(std::mem::size_of_val(factor_cache) > 0);
self.trace_projected_factor(factor)
}
/// Combines `factor` with the operator applied to `factor` through
/// `fast_atb`.
///
/// NOTE(review): `fast_atb` is defined elsewhere; by its name it computes
/// `Aᵀ·B`, which would make this `factorᵀ · op · factor` — confirm.
fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
let op_factor = self.mul_mat(factor);
crate::faer_ndarray::fast_atb(factor, &op_factor)
}
/// Cache-aware variant of `projected_matrix`. The default does not consult
/// the cache and delegates to the uncached method.
fn projected_matrix_cached(
&self,
factor: &Array2<f64>,
factor_cache: &ProjectedFactorCache,
) -> Array2<f64> {
assert!(std::mem::size_of_val(factor_cache) > 0);
self.projected_matrix(factor)
}
/// Writes the operator applied to unit vectors `e_start, e_start+1, ...`
/// into the consecutive columns of `out`, one per column.
///
/// Panics if `start + out.ncols()` exceeds `out.nrows()`.
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
let dim = out.nrows();
assert!(start + cols <= dim);
// One reusable unit vector: set the active entry, apply, then clear it.
let mut basis = Array1::<f64>::zeros(dim);
for local_col in 0..cols {
let global_col = start + local_col;
basis[global_col] = 1.0;
self.mul_vec_into(basis.view(), out.column_mut(local_col));
basis[global_col] = 0.0;
}
}
/// Accumulates `scale * (op · v)` into `out`; does nothing when `scale` is
/// exactly zero. The default goes through a temporary work vector.
fn scaled_add_mul_vec(
&self,
v: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
if scale == 0.0 {
return;
}
let mut work = Array1::<f64>::zeros(out.len());
self.mul_vec_into(v, work.view_mut());
out.scaled_add(scale, &work);
}
/// Returns the bilinear form `u · (op · v)`.
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v.view(), bv.view_mut());
u.dot(&bv)
}
/// View-based variant of `bilinear`; returns `u · (op · v)`.
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
let mut bv = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v, bv.view_mut());
u.dot(&bv)
}
/// Capability flag, `false` by default.
///
/// NOTE(review): presumably implementors return `true` when `bilinear_view`
/// avoids the generic matvec path — confirm against callers.
fn has_fast_bilinear_view(&self) -> bool {
false
}
/// Materializes the operator as a `dim() x dim()` matrix by applying it to
/// each unit vector in turn.
fn to_dense(&self) -> Array2<f64> {
let p = self.dim();
let mut out = Array2::<f64>::zeros((p, p));
let mut basis = Array1::<f64>::zeros(p);
for j in 0..p {
basis[j] = 1.0;
self.mul_vec_into(basis.view(), out.column_mut(j));
basis[j] = 0.0;
}
out
}
/// Reports whether the operator is "implicit"; the exact meaning is defined
/// by the implementors.
fn is_implicit(&self) -> bool;
/// Downcast hook: `Some` when this is an [`ImplicitHyperOperator`].
fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
None
}
/// Downcast hook: `Some` when this is a [`CompositeHyperOperator`].
fn as_composite(&self) -> Option<&CompositeHyperOperator> {
None
}
/// Downcast hook: `Some` when this is a [`WeightedHyperOperator`].
fn as_weighted(&self) -> Option<&WeightedHyperOperator> {
None
}
/// For block-local operators, returns `(local block, start, end)`; `None`
/// by default.
fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
None
}
/// Downcast hook: `Some` when this is a [`SparseDirectionalHyperOperator`].
fn as_sparse_directional(&self) -> Option<&SparseDirectionalHyperOperator> {
None
}
}
/// Hashable identity of a factor matrix, used as the key of
/// [`ProjectedFactorCache`].
///
/// It combines a caller-supplied design id with the factor's address, shape,
/// strides, and two fingerprints of its values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProjectedFactorKey {
/// Caller-supplied identifier (presumably of the design the factor is projected through — TODO confirm).
pub(crate) design_id: usize,
/// Address of the factor's first element, stored as an integer.
pub(crate) factor_ptr: usize,
/// Number of rows of the factor.
pub(crate) rows: usize,
/// Number of columns of the factor.
pub(crate) cols: usize,
/// Stride of the factor along axis 0.
pub(crate) row_stride: isize,
/// Stride of the factor along axis 1.
pub(crate) col_stride: isize,
/// First value fingerprint from `projected_factor_value_fingerprint`.
pub(crate) value_hash: u64,
/// Second, independently mixed value fingerprint.
pub(crate) value_hash2: u64,
}
impl ProjectedFactorKey {
    /// Builds the cache key for `factor` under the given `design_id`.
    ///
    /// Records the view's data address, shape, and strides, plus a two-part
    /// fingerprint of its values.
    pub fn from_factor_view(design_id: usize, factor: ArrayView2<'_, f64>) -> Self {
        let (rows, cols) = factor.dim();
        let factor_ptr = factor.as_ptr() as usize;
        let (row_stride, col_stride) = {
            let strides = factor.strides();
            (strides[0], strides[1])
        };
        let (value_hash, value_hash2) = projected_factor_value_fingerprint(factor);
        Self {
            design_id,
            factor_ptr,
            rows,
            cols,
            row_stride,
            col_stride,
            value_hash,
            value_hash2,
        }
    }
}
/// Computes two position-sensitive 64-bit fingerprints of `factor`'s values.
///
/// Elements are visited in the view's logical iteration order. The first hash
/// is an FNV-1a-style xor/multiply over the bit pattern offset by the element
/// index; the second xors in the bit pattern rotated by the index, then
/// multiplies and rotates.
pub(crate) fn projected_factor_value_fingerprint(factor: ArrayView2<'_, f64>) -> (u64, u64) {
    const H1_SEED: u64 = 0xcbf2_9ce4_8422_2325;
    const H1_PRIME: u64 = 0x0000_0100_0000_01b3;
    const H2_SEED: u64 = 0x9e37_79b1_85eb_ca87;
    const H2_MULTIPLIER: u64 = 0x94d0_49bb_1331_11eb;
    const INDEX_MULTIPLIER: u64 = 0x517c_c1b7_2722_0a95;

    factor
        .iter()
        .enumerate()
        .fold((H1_SEED, H2_SEED), |(h1, h2), (idx, value)| {
            let bits = value.to_bits();
            let position_salt = (idx as u64).wrapping_mul(INDEX_MULTIPLIER);
            let mixed = bits.wrapping_add(position_salt);
            let next_h1 = (h1 ^ mixed).wrapping_mul(H1_PRIME);
            let rotated = bits.rotate_left((idx & 63) as u32);
            let next_h2 = (h2 ^ rotated).wrapping_mul(H2_MULTIPLIER).rotate_left(27);
            (next_h1, next_h2)
        })
}
/// Thread-safe cache of matrices keyed by [`ProjectedFactorKey`], with
/// least-recently-used eviction driven by a byte budget.
pub struct ProjectedFactorCache {
/// All cache state, guarded by a single mutex.
pub(crate) inner: Mutex<ProjectedFactorCacheInner>,
}
/// Mutable state of a [`ProjectedFactorCache`]; only touched while holding its lock.
pub(crate) struct ProjectedFactorCacheInner {
/// Completed entries.
pub(crate) entries: HashMap<ProjectedFactorKey, ProjectedFactorEntry>,
/// Keys currently being computed, each with the marker that waiters block on.
pub(crate) in_progress: HashMap<ProjectedFactorKey, Arc<ProjectedFactorInProgress>>,
/// Counter bumped on every lookup and insertion; its value stamps `last_used`.
pub(crate) next_seq: u64,
/// Running sum of the `bytes` of all stored entries.
pub(crate) total_bytes: usize,
/// Byte budget that the eviction loop tries to stay under.
pub(crate) budget_bytes: usize,
}
/// Rendezvous between the thread computing a cache value and threads waiting for it.
pub(crate) struct ProjectedFactorInProgress {
/// `None` while the computation runs; set to `Ready` or `Failed` by the producer.
pub(crate) state: Mutex<Option<ProjectedFactorInProgressState>>,
/// Signalled by the producer after `state` has been set.
pub(crate) ready: Condvar,
/// Number of threads currently waiting on this marker.
pub(crate) waiter_count: std::sync::atomic::AtomicUsize,
/// Notified each time a waiter registers. Nothing in the visible code waits
/// on it — presumably a synchronization hook for tests; TODO confirm.
pub(crate) subscriber_arrived: (Mutex<()>, Condvar),
}
/// Final outcome published by the producer of an in-progress cache computation.
pub(crate) enum ProjectedFactorInProgressState {
/// The computation finished; waiters receive a clone of this shared value.
Ready(Arc<Array2<f64>>),
/// The compute closure panicked; no value is available.
Failed,
}
/// One stored value of a [`ProjectedFactorCache`] with its bookkeeping.
pub(crate) struct ProjectedFactorEntry {
/// The cached matrix, shared with callers.
pub(crate) value: Arc<Array2<f64>>,
/// Size charged against the budget (element count times `size_of::<f64>()`).
pub(crate) bytes: usize,
/// Sequence number of the most recent access; the smallest is evicted first.
pub(crate) last_used: u64,
}
impl Default for ProjectedFactorCache {
    /// Creates an empty cache with the default byte budget
    /// (`DEFAULT_BUDGET_BYTES`).
    fn default() -> Self {
        let budget_bytes = Self::DEFAULT_BUDGET_BYTES;
        Self::with_budget(budget_bytes)
    }
}
impl ProjectedFactorCache {
/// Default byte budget: 2 GiB.
pub const DEFAULT_BUDGET_BYTES: usize = 2 * 1024 * 1024 * 1024;
/// Creates an empty cache with the given byte budget.
pub fn with_budget(budget_bytes: usize) -> Self {
Self {
inner: Mutex::new(ProjectedFactorCacheInner {
entries: HashMap::new(),
in_progress: HashMap::new(),
next_seq: 0,
total_bytes: 0,
budget_bytes,
}),
}
}
/// Returns the value cached under `key`, computing it with `compute` when absent.
///
/// If another thread is already computing the same key, this call blocks
/// until that thread publishes its result instead of computing again. The
/// `compute` closure runs without the cache lock held.
///
/// # Panics
///
/// Panics if an internal lock is poisoned. If `compute` panics, the original
/// panic is resumed on this thread, and threads waiting on the same key go
/// through `reml_contract_panic`.
pub fn get_or_insert_with(
&self,
key: ProjectedFactorKey,
compute: impl FnOnce() -> Array2<f64>,
) -> Arc<Array2<f64>> {
// Decision made under the cache lock, acted on after the lock is released.
enum CacheLookup {
Hit(Arc<Array2<f64>>),
Wait(Arc<ProjectedFactorInProgress>),
Compute(Arc<ProjectedFactorInProgress>),
}
let lookup = {
let mut inner = self
.inner
.lock()
.expect("projected factor cache lock poisoned");
// Every lookup advances the sequence counter used for LRU ordering.
inner.next_seq += 1;
let now = inner.next_seq;
if let Some(entry) = inner.entries.get_mut(&key) {
entry.last_used = now;
CacheLookup::Hit(entry.value.clone())
} else if let Some(waiter) = inner.in_progress.get(&key) {
CacheLookup::Wait(waiter.clone())
} else {
// First thread to ask for this key: register a marker so later
// callers wait rather than duplicate the computation.
let marker = Arc::new(ProjectedFactorInProgress {
state: Mutex::new(None),
ready: Condvar::new(),
waiter_count: std::sync::atomic::AtomicUsize::new(0),
subscriber_arrived: (Mutex::new(()), Condvar::new()),
});
inner.in_progress.insert(key, marker.clone());
CacheLookup::Compute(marker)
}
};
match lookup {
CacheLookup::Hit(value) => value,
CacheLookup::Wait(marker) => {
// Register as a waiter and announce the arrival.
marker
.waiter_count
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
let (lock, cv) = &marker.subscriber_arrived;
// Take and immediately release the notification lock before
// notifying, so the notification is ordered after any thread
// that currently holds that lock.
drop(
lock.lock()
.expect("subscriber-arrived notification lock poisoned"),
);
cv.notify_all();
let mut guard = marker
.state
.lock()
.expect("projected factor in-progress lock poisoned");
// The state is checked before each wait, so a result published
// before this point is not missed.
let result = loop {
match guard.as_ref() {
Some(ProjectedFactorInProgressState::Ready(value)) => {
break value.clone();
}
Some(ProjectedFactorInProgressState::Failed) => {
marker
.waiter_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
// NOTE(review): `reml_contract_panic` is defined elsewhere and
// presumably never returns. If it did return, this loop would
// repeat this arm (and the decrement) indefinitely — confirm.
reml_contract_panic("projected factor cache producer panicked")
}
None => {
guard = marker
.ready
.wait(guard)
.expect("projected factor in-progress wait poisoned");
}
}
};
marker
.waiter_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
result
}
CacheLookup::Compute(marker) => {
// Run the user closure outside the cache lock, catching a panic so the
// in-progress marker can be cleaned up and waiters released.
let computed = match catch_unwind(AssertUnwindSafe(|| Arc::new(compute()))) {
Ok(value) => value,
Err(payload) => {
let mut inner = self
.inner
.lock()
.expect("projected factor cache lock poisoned");
inner.in_progress.remove(&key);
drop(inner);
let mut guard = marker
.state
.lock()
.expect("projected factor in-progress lock poisoned");
*guard = Some(ProjectedFactorInProgressState::Failed);
marker.ready.notify_all();
// Re-raise the original panic on the producing thread.
resume_unwind(payload);
}
};
let bytes = computed.len().saturating_mul(std::mem::size_of::<f64>());
let mut inner = self
.inner
.lock()
.expect("projected factor cache lock poisoned");
inner.next_seq += 1;
let now = inner.next_seq;
// Eviction runs only when the budget is non-zero and the new entry can
// fit within it; entries with the smallest `last_used` go first.
if inner.budget_bytes > 0 && bytes <= inner.budget_bytes {
while inner.total_bytes.saturating_add(bytes) > inner.budget_bytes
&& !inner.entries.is_empty()
{
let Some(oldest_key) = inner
.entries
.iter()
.min_by_key(|(_, e)| e.last_used)
.map(|(k, _)| *k)
else {
break;
};
if let Some(removed) = inner.entries.remove(&oldest_key) {
inner.total_bytes = inner.total_bytes.saturating_sub(removed.bytes);
}
}
}
// NOTE(review): the entry is stored even when the budget is zero or the
// entry alone exceeds the budget, so `total_bytes` can exceed
// `budget_bytes`. Confirm whether a zero budget means "unlimited" or
// whether such entries should be returned without being cached.
let value = if let Some(entry) = inner.entries.get_mut(&key) {
entry.last_used = now;
entry.value.clone()
} else {
inner.entries.insert(
key,
ProjectedFactorEntry {
value: computed.clone(),
bytes,
last_used: now,
},
);
inner.total_bytes = inner.total_bytes.saturating_add(bytes);
computed
};
inner.in_progress.remove(&key);
drop(inner);
// Publish the result to threads blocked on this marker.
let mut guard = marker
.state
.lock()
.expect("projected factor in-progress lock poisoned");
*guard = Some(ProjectedFactorInProgressState::Ready(value.clone()));
marker.ready.notify_all();
value
}
}
}
/// Number of stored entries; returns 0 if the cache lock is poisoned.
pub fn len(&self) -> usize {
self.inner
.lock()
.map(|inner| inner.entries.len())
.unwrap_or(0)
}
/// Total bytes charged to stored entries; returns 0 if the cache lock is poisoned.
pub fn total_bytes(&self) -> usize {
self.inner
.lock()
.map(|inner| inner.total_bytes)
.unwrap_or(0)
}
/// `true` when `len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// [`HyperOperator`] backed by an explicit dense matrix.
#[derive(Clone)]
pub struct DenseMatrixHyperOperator {
/// The matrix that represents the operator.
pub matrix: Array2<f64>,
}
impl HyperOperator for DenseMatrixHyperOperator {
    /// Row count of the stored matrix.
    fn dim(&self) -> usize {
        let (rows, _) = self.matrix.dim();
        rows
    }

    /// Dense matrix-vector product.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let matrix = &self.matrix;
        matrix.dot(v)
    }

    /// Dense matrix-vector product on a view, without copying `v`.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let matrix = &self.matrix;
        matrix.dot(&v)
    }

    /// Writes the product into `out` through the shared dense helper.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_into(&self.matrix, v, out);
    }

    /// Applying the matrix to unit vectors just selects columns, so this
    /// copies columns `start..start + out.ncols()` into `out`.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        let width = out.ncols();
        let end = start + width;
        assert!(end <= self.matrix.ncols());
        let selected = self.matrix.slice(ndarray::s![.., start..end]);
        out.assign(&selected);
    }

    /// Accumulates the scaled product into `out` through the shared dense helper.
    fn scaled_add_mul_vec(&self, v: ArrayView1<'_, f64>, scale: f64, out: ArrayViewMut1<'_, f64>) {
        dense_matvec_scaled_add_into(&self.matrix, v, scale, out);
    }

    /// Bilinear form on owned vectors; same computation as `bilinear_view`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let (v_view, u_view) = (v.view(), u.view());
        dense_bilinear(&self.matrix, v_view, u_view)
    }

    /// Bilinear form on views through the shared dense helper.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        dense_bilinear(&self.matrix, v, u)
    }

    /// Returns a copy of the stored matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.matrix.clone()
    }

    /// A dense matrix is never implicit.
    fn is_implicit(&self) -> bool {
        false
    }
}
/// Sum of an optional dense matrix and any number of child operators.
#[derive(Clone)]
pub struct CompositeHyperOperator {
/// Optional dense summand.
pub dense: Option<Array2<f64>>,
/// Child operators whose actions are added together.
pub operators: Vec<Arc<dyn HyperOperator>>,
/// Value reported by `dim()`; also sizes the zero matrix in `to_dense`
/// when there is no dense summand.
pub dim_hint: usize,
}
/// Sums the projected-factor traces of `operators`, batching implicit
/// operators that share their underlying data.
///
/// Implicit operators are grouped when they point at the same
/// `implicit_deriv`, `x_design`, and `w_diag` allocations and have equal `p`.
/// A group of two or more is evaluated in one multi-axis call that shares a
/// single `xf`. Singletons and non-implicit operators use their own trace
/// method, cached when `cache` is provided.
pub(crate) fn composite_trace_implicit_batched(
    operators: &[Arc<dyn HyperOperator>],
    factor: &Array2<f64>,
    cache: Option<&ProjectedFactorCache>,
) -> f64 {
    let count = operators.len();
    let single_trace = |op: &Arc<dyn HyperOperator>| match cache {
        Some(shared) => op.trace_projected_factor_cached(factor, shared),
        None => op.trace_projected_factor(factor),
    };

    // Pass 1: partition the implicit operators into groups, in order of first appearance.
    let mut grouped = vec![false; count];
    let mut groups: Vec<Vec<usize>> = Vec::new();
    for lead_idx in 0..count {
        if grouped[lead_idx] {
            continue;
        }
        let Some(lead) = operators[lead_idx].as_implicit() else {
            continue;
        };
        grouped[lead_idx] = true;
        let mut members = vec![lead_idx];
        for other_idx in (lead_idx + 1)..count {
            if grouped[other_idx] {
                continue;
            }
            let shares_data = operators[other_idx].as_implicit().is_some_and(|other| {
                Arc::ptr_eq(&lead.implicit_deriv, &other.implicit_deriv)
                    && Arc::ptr_eq(&lead.x_design, &other.x_design)
                    && Arc::ptr_eq(lead.w_diag.as_arc(), other.w_diag.as_arc())
                    && lead.p == other.p
            });
            if shares_data {
                members.push(other_idx);
                grouped[other_idx] = true;
            }
        }
        groups.push(members);
    }

    // Pass 2: evaluate each group, batched when it has at least two members.
    let mut total = 0.0;
    for members in &groups {
        if members.len() < 2 {
            total += single_trace(&operators[members[0]]);
            continue;
        }
        let lead = operators[members[0]].as_implicit().unwrap();
        let xf = match cache {
            Some(shared) => lead.cached_xf(factor, shared),
            None => Arc::new(lead.compute_xf(factor)),
        };
        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = members
            .iter()
            .map(|&member| {
                let op = operators[member].as_implicit().unwrap();
                (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
            })
            .collect();
        let per_axis = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
        total += per_axis.iter().sum::<f64>();
    }

    // Pass 3: everything that never joined a group (the non-implicit operators).
    for (idx, op) in operators.iter().enumerate() {
        if !grouped[idx] {
            total += single_trace(op);
        }
    }
    total
}
/// Computes one projected-factor trace per operator, returned in input order.
///
/// Implicit operators that point at the same `implicit_deriv`, `x_design`,
/// and `w_diag` allocations and have equal `p` are evaluated together in a
/// single multi-axis call that shares one cached `xf`. Every other operator
/// goes through its own cached trace method.
pub(crate) fn trace_projected_factors_batched(
    operators: &[Arc<dyn HyperOperator>],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let count = operators.len();
    let mut traces = vec![0.0; count];
    let mut done = vec![false; count];
    for lead_idx in 0..count {
        if done[lead_idx] {
            continue;
        }
        done[lead_idx] = true;
        let Some(lead) = operators[lead_idx].as_implicit() else {
            // Non-implicit operators are never batched.
            traces[lead_idx] = operators[lead_idx].trace_projected_factor_cached(factor, cache);
            continue;
        };

        // Gather later implicit operators that share this one's data.
        let mut members = vec![lead_idx];
        for other_idx in (lead_idx + 1)..count {
            if done[other_idx] {
                continue;
            }
            let shares_data = operators[other_idx].as_implicit().is_some_and(|other| {
                Arc::ptr_eq(&lead.implicit_deriv, &other.implicit_deriv)
                    && Arc::ptr_eq(&lead.x_design, &other.x_design)
                    && Arc::ptr_eq(lead.w_diag.as_arc(), other.w_diag.as_arc())
                    && lead.p == other.p
            });
            if shares_data {
                members.push(other_idx);
                done[other_idx] = true;
            }
        }

        if members.len() < 2 {
            traces[lead_idx] = operators[lead_idx].trace_projected_factor_cached(factor, cache);
            continue;
        }

        let xf = lead.cached_xf(factor, cache);
        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = members
            .iter()
            .map(|&member| {
                let op = operators[member].as_implicit().unwrap();
                (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
            })
            .collect();
        let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
        for (&member, value) in members.iter().zip(values) {
            traces[member] = value;
        }
    }
    traces
}
/// Flattens `op` into weighted leaf terms for a later batched trace.
///
/// Composite and weighted operators are walked recursively. A composite's
/// dense part is traced immediately and added to `dense_acc[out_idx]`; every
/// other leaf is appended to `terms` as `(out_idx, weight, operator)`. A zero
/// `weight` contributes nothing.
pub(crate) fn collect_projected_trace_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [f64],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }

    if let Some(composite) = op.as_composite() {
        if let Some(dense) = &composite.dense {
            dense_acc[out_idx] += weight * dense_trace_projected_factor(dense, factor);
        }
        for child in composite.operators.iter() {
            let child_ref = child.as_ref();
            collect_projected_trace_terms(out_idx, weight, child_ref, factor, dense_acc, terms);
        }
        return;
    }

    if let Some(weighted) = op.as_weighted() {
        for (term_weight, child) in weighted.terms.iter() {
            // Weights multiply down the tree.
            let combined = weight * *term_weight;
            let child_ref = child.as_ref();
            collect_projected_trace_terms(out_idx, combined, child_ref, factor, dense_acc, terms);
        }
        return;
    }

    terms.push((out_idx, weight, op));
}
/// Flattens `op` into weighted leaf terms for a later batched projection.
///
/// Composite and weighted operators are walked recursively. A composite's
/// dense part is projected immediately and accumulated, scaled by `weight`,
/// into `dense_acc[out_idx]`; every other leaf is appended to `terms` as
/// `(out_idx, weight, operator)`. A zero `weight` contributes nothing.
pub(crate) fn collect_projected_matrix_terms<'a>(
    out_idx: usize,
    weight: f64,
    op: &'a dyn HyperOperator,
    factor: &Array2<f64>,
    dense_acc: &mut [Array2<f64>],
    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
) {
    if weight == 0.0 {
        return;
    }

    if let Some(composite) = op.as_composite() {
        if let Some(dense) = &composite.dense {
            let projected = dense_projected_matrix(dense, factor);
            dense_acc[out_idx].scaled_add(weight, &projected);
        }
        for child in composite.operators.iter() {
            let child_ref = child.as_ref();
            collect_projected_matrix_terms(out_idx, weight, child_ref, factor, dense_acc, terms);
        }
        return;
    }

    if let Some(weighted) = op.as_weighted() {
        for (term_weight, child) in weighted.terms.iter() {
            // Weights multiply down the tree.
            let combined = weight * *term_weight;
            let child_ref = child.as_ref();
            collect_projected_matrix_terms(out_idx, combined, child_ref, factor, dense_acc, terms);
        }
        return;
    }

    terms.push((out_idx, weight, op));
}
/// Accumulates weighted projected-factor traces into `n_out` output slots.
///
/// Each term is `(output index, weight, operator)`. Implicit operators that
/// point at the same `implicit_deriv`, `x_design`, and `w_diag` allocations
/// and have equal `p` are evaluated together with one cached `xf`; an implicit
/// operator with no partner still goes through that multi-axis path.
/// Non-implicit terms use their own cached trace method.
pub(crate) fn trace_projected_operator_terms_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let term_count = terms.len();
    let mut totals = vec![0.0_f64; n_out];
    let mut batched = vec![false; term_count];

    for lead_idx in 0..term_count {
        if batched[lead_idx] {
            continue;
        }
        let Some(lead) = terms[lead_idx].2.as_implicit() else {
            continue;
        };
        batched[lead_idx] = true;

        // Gather later implicit terms that share this one's data.
        let mut members = vec![lead_idx];
        for other_idx in (lead_idx + 1)..term_count {
            if batched[other_idx] {
                continue;
            }
            let shares_data = terms[other_idx].2.as_implicit().is_some_and(|other| {
                Arc::ptr_eq(&lead.implicit_deriv, &other.implicit_deriv)
                    && Arc::ptr_eq(&lead.x_design, &other.x_design)
                    && Arc::ptr_eq(lead.w_diag.as_arc(), other.w_diag.as_arc())
                    && lead.p == other.p
            });
            if shares_data {
                members.push(other_idx);
                batched[other_idx] = true;
            }
        }

        let xf = lead.cached_xf(factor, cache);
        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = members
            .iter()
            .map(|&member| {
                let op = terms[member].2.as_implicit().unwrap();
                (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
            })
            .collect();
        let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
        for (&member, value) in members.iter().zip(values.iter()) {
            let (out_idx, weight, _) = terms[member];
            totals[out_idx] += weight * *value;
        }
    }

    // Remaining terms are the non-implicit operators.
    for (idx, &(out_idx, weight, op)) in terms.iter().enumerate() {
        if batched[idx] {
            continue;
        }
        totals[out_idx] += weight * op.trace_projected_factor_cached(factor, cache);
    }
    totals
}
/// Accumulates weighted projected matrices into `n_out` output slots.
///
/// Each slot starts as a `rank x rank` zero matrix, where `rank` is the column
/// count of `factor`. Each term `(output index, weight, operator)` adds
/// `weight` times the operator's cached projected matrix to its slot.
pub(crate) fn projected_operator_terms_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    let rank = factor.ncols();
    let mut accumulators = vec![Array2::<f64>::zeros((rank, rank)); n_out];
    for &(out_idx, weight, op) in terms {
        let projected = op.projected_matrix_cached(factor, cache);
        accumulators[out_idx].scaled_add(weight, &projected);
    }
    accumulators
}
/// Alias for [`projected_operator_terms_batched`]: forwards all arguments
/// unchanged and returns its result.
pub(crate) fn project_hyper_operators_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    projected_operator_terms_batched(
        n_out,
        terms,
        factor,
        cache,
    )
}
/// Computes one projected-factor trace per drift, returned in input order.
///
/// Dense drifts, and dense parts found inside composite operators, are traced
/// directly. All other operator leaves are collected and evaluated through
/// the batched operator path, then added to the direct contributions.
pub(crate) fn trace_logdet_drifts_projected_factor_batched(
    drifts: &[DriftDerivResult],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<f64> {
    let n_drifts = drifts.len();
    let mut totals = vec![0.0_f64; n_drifts];
    let mut operator_terms: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();

    for (idx, drift) in drifts.iter().enumerate() {
        match drift {
            DriftDerivResult::Operator(op) => {
                let op_ref = op.as_ref();
                collect_projected_trace_terms(
                    idx,
                    1.0,
                    op_ref,
                    factor,
                    &mut totals,
                    &mut operator_terms,
                );
            }
            DriftDerivResult::Dense(matrix) => {
                totals[idx] += dense_trace_projected_factor(matrix, factor);
            }
        }
    }

    let batched =
        trace_projected_operator_terms_batched(n_drifts, &operator_terms, factor, cache);
    for (total, operator_part) in totals.iter_mut().zip(batched) {
        *total += operator_part;
    }
    totals
}
/// Traces each drift against the spectral operator's `g_factor`, reusing the
/// operator's own projected-factor cache.
pub(crate) fn dense_spectral_trace_logdet_drifts_batched(
    ds: &DenseSpectralOperator,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    let factor = &ds.g_factor;
    let cache = &ds.projected_factor_cache;
    trace_logdet_drifts_projected_factor_batched(drifts, factor, cache)
}
/// Builds the factor `u_s · V · diag(sqrt(max(λ, 0)))` from the symmetric
/// eigendecomposition `(λ, V)` of `kernel.h_proj_inverse`.
///
/// Negative eigenvalues are clamped to zero before the square root, so they
/// contribute zero columns instead of producing NaNs.
///
/// # Panics
///
/// Panics if the eigendecomposition fails.
pub(crate) fn penalty_subspace_trace_factor(kernel: &PenaltySubspaceTrace) -> Array2<f64> {
    // The eigenvector matrix is owned, so scale it in place rather than
    // cloning it into a second buffer.
    let (evals, mut root) = kernel
        .h_proj_inverse
        .eigh(faer::Side::Lower)
        .expect("PenaltySubspaceTrace kernel factor eigendecomposition failed");
    for (col, &eval) in evals.iter().enumerate() {
        let scale = eval.max(0.0).sqrt();
        root.column_mut(col).mapv_inplace(|entry| entry * scale);
    }
    crate::faer_ndarray::fast_ab(&kernel.u_s, &root)
}
/// Traces each drift against the factor derived from `kernel`.
///
/// The factor is rebuilt on every call, and the projected-factor cache is a
/// fresh, call-local one, so nothing is reused across calls.
pub(crate) fn penalty_subspace_trace_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    let kernel_factor = penalty_subspace_trace_factor(kernel);
    let local_cache = ProjectedFactorCache::default();
    trace_logdet_drifts_projected_factor_batched(drifts, &kernel_factor, &local_cache)
}
/// Reduces every drift through `kernel`, in input order.
///
/// Dense drifts use `kernel.reduce`; operator drifts use
/// `kernel.reduce_operator`.
pub(crate) fn penalty_subspace_reduce_drifts_batched(
    kernel: &PenaltySubspaceTrace,
    drifts: &[DriftDerivResult],
) -> Vec<Array2<f64>> {
    let mut reduced = Vec::with_capacity(drifts.len());
    for drift in drifts {
        let matrix = match drift {
            DriftDerivResult::Operator(op) => kernel.reduce_operator(op.as_ref()),
            DriftDerivResult::Dense(dense) => kernel.reduce(dense),
        };
        reduced.push(matrix);
    }
    reduced
}
/// Computes one projected-factor trace per operator against `ds.g_factor`.
///
/// Returns an empty vector for empty input. When `Info` logging is enabled the
/// batched call is timed and reported with the dimension, rank, operator
/// count, and implicit-operator count; otherwise the traces are computed
/// without any instrumentation.
pub(crate) fn dense_spectral_trace_logdet_operators_batched(
    ds: &DenseSpectralOperator,
    operators: &[Arc<dyn HyperOperator>],
) -> Vec<f64> {
    if operators.is_empty() {
        return Vec::new();
    }
    let run_batched =
        || trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache);
    if !log::log_enabled!(log::Level::Info) {
        return run_batched();
    }

    let timer = std::time::Instant::now();
    let traces = run_batched();
    let implicit_count = operators.iter().filter(|op| op.is_implicit()).count();
    let label = format!(
        "DenseSpectralOperator::trace_logdet_operators_batched dim={} rank={} ops={} implicit_ops={}",
        ds.n_dim,
        ds.g_factor.ncols(),
        operators.len(),
        implicit_count,
    );
    dense_spectral_stage_log(&label, timer.elapsed().as_secs_f64());
    traces
}
/// The composite acts as the sum of its optional dense matrix and all child
/// operators. Most methods start with a fast path: when there is no dense part
/// and exactly one child, the call is forwarded to that child unchanged.
impl HyperOperator for CompositeHyperOperator {
/// Exposes this operator as a composite so callers can flatten it.
fn as_composite(&self) -> Option<&CompositeHyperOperator> {
Some(self)
}
/// Returns the stored `dim_hint`.
fn dim(&self) -> usize {
self.dim_hint
}
/// Allocates a zero vector of `v`'s length and fills it through `mul_vec_into`.
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v.view(), out.view_mut());
out
}
/// View-based variant of `mul_vec`.
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(v.len());
self.mul_vec_into(v, out.view_mut());
out
}
/// Overwrites `out` with the dense product (if any) plus every child's product.
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
if self.dense.is_none() && self.operators.len() == 1 {
self.operators[0].mul_vec_into(v, out);
return;
}
// Start from zero, write the dense product, then accumulate each child
// with unit scale.
out.fill(0.0);
if let Some(dense) = self.dense.as_ref() {
dense_matvec_into(dense, v, out.view_mut());
}
for op in &self.operators {
op.scaled_add_mul_vec(v, 1.0, out.view_mut());
}
}
/// Overwrites `out` with columns `start..start + out.ncols()` of the summed operator.
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
if self.dense.is_none() && self.operators.len() == 1 {
self.operators[0].mul_basis_columns_into(start, out);
return;
}
out.fill(0.0);
let cols = out.ncols();
let end = start + cols;
if let Some(dense) = self.dense.as_ref() {
out += &dense.slice(ndarray::s![.., start..end]);
}
// One scratch matrix is reused for all children: each child writes its
// columns into `work`, which is then added to `out`.
let mut work = Array2::<f64>::zeros((out.nrows(), cols));
for op in &self.operators {
op.mul_basis_columns_into(start, work.view_mut());
out += &work;
}
}
/// Accumulates `scale` times the summed operator applied to `v` into `out`;
/// does nothing when `scale` is exactly zero.
fn scaled_add_mul_vec(
&self,
v: ArrayView1<'_, f64>,
scale: f64,
mut out: ArrayViewMut1<'_, f64>,
) {
if scale == 0.0 {
return;
}
if self.dense.is_none() && self.operators.len() == 1 {
self.operators[0].scaled_add_mul_vec(v, scale, out);
return;
}
if let Some(dense) = self.dense.as_ref() {
dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
}
for op in &self.operators {
op.scaled_add_mul_vec(v, scale, out.view_mut());
}
}
/// Applies the summed operator to `factor` by adding the dense product and
/// each child's `mul_mat` result.
fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
if self.dense.is_none() && self.operators.len() == 1 {
return self.operators[0].mul_mat(factor);
}
let p = factor.nrows();
let k = factor.ncols();
let mut out = Array2::<f64>::zeros((p, k));
if let Some(dense) = self.dense.as_ref() {
out += &dense.dot(factor);
}
for op in &self.operators {
out += &op.mul_mat(factor);
}
out
}
/// Uncached projected-factor trace: the dense part is traced directly and
/// the children go through the grouped implicit path without a cache.
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
if self.dense.is_none() && self.operators.len() == 1 {
return self.operators[0].trace_projected_factor(factor);
}
let mut trace = 0.0;
if let Some(dense) = self.dense.as_ref() {
// Sum of the elementwise product of `factor` and `dense · factor`.
let dense_factor = dense.dot(factor);
trace += factor
.iter()
.zip(dense_factor.iter())
.map(|(&f, &bf)| f * bf)
.sum::<f64>();
}
trace += composite_trace_implicit_batched(&self.operators, factor, None);
trace
}
/// Cached projected-factor trace: same as the uncached version, but the
/// children are evaluated with `cache`.
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> f64 {
if self.dense.is_none() && self.operators.len() == 1 {
return self.operators[0].trace_projected_factor_cached(factor, cache);
}
let mut trace = 0.0;
if let Some(dense) = self.dense.as_ref() {
let dense_factor = dense.dot(factor);
trace += factor
.iter()
.zip(dense_factor.iter())
.map(|(&f, &bf)| f * bf)
.sum::<f64>();
}
trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
trace
}
/// Sums the projected matrices of the dense part and every child into a
/// `rank x rank` matrix, where `rank` is the column count of `factor`.
fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
if self.dense.is_none() && self.operators.len() == 1 {
return self.operators[0].projected_matrix(factor);
}
let rank = factor.ncols();
let mut projected = Array2::<f64>::zeros((rank, rank));
if let Some(dense) = self.dense.as_ref() {
// NOTE(review): `fast_ab` / `fast_atb` are defined elsewhere; by their
// names they compute `A·B` and `Aᵀ·B`, giving `factorᵀ · dense · factor`
// here — confirm.
let mf = crate::faer_ndarray::fast_ab(dense, factor);
projected += &crate::faer_ndarray::fast_atb(factor, &mf);
}
for op in &self.operators {
projected += &op.projected_matrix(factor);
}
projected
}
/// Cached variant of `projected_matrix`; children use their cached projection.
fn projected_matrix_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> Array2<f64> {
if self.dense.is_none() && self.operators.len() == 1 {
return self.operators[0].projected_matrix_cached(factor, cache);
}
let rank = factor.ncols();
let mut projected = Array2::<f64>::zeros((rank, rank));
if let Some(dense) = self.dense.as_ref() {
let mf = crate::faer_ndarray::fast_ab(dense, factor);
projected += &crate::faer_ndarray::fast_atb(factor, &mf);
}
for op in &self.operators {
projected += &op.projected_matrix_cached(factor, cache);
}
projected
}
/// Sums the dense bilinear form and every child's bilinear form. Unlike the
/// methods above, this has no single-child fast path.
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
let mut total = 0.0;
if let Some(dense) = self.dense.as_ref() {
total += dense_bilinear(dense, v.view(), u.view());
}
for op in &self.operators {
total += op.bilinear(v, u);
}
total
}
/// View-based variant of `bilinear`.
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
let mut total = 0.0;
if let Some(dense) = self.dense.as_ref() {
total += dense_bilinear(dense, v, u);
}
for op in &self.operators {
total += op.bilinear_view(v, u);
}
total
}
/// Materializes the sum: starts from a copy of the dense part (or a
/// `dim_hint x dim_hint` zero matrix) and adds each child's dense form.
fn to_dense(&self) -> Array2<f64> {
let mut out = self
.dense
.clone()
.unwrap_or_else(|| Array2::<f64>::zeros((self.dim_hint, self.dim_hint)));
for op in &self.operators {
out += &op.to_dense();
}
out
}
/// `true` when at least one child reports itself as implicit.
fn is_implicit(&self) -> bool {
self.operators.iter().any(|op| op.is_implicit())
}
}
/// Operator that is zero everywhere except for one square diagonal block.
#[derive(Clone)]
pub struct BlockLocalDrift {
/// The block's entries; it occupies rows and columns `start..end` of the full operator.
pub local: Array2<f64>,
/// First global index covered by the block (inclusive).
pub start: usize,
/// One past the last global index covered by the block (exclusive).
pub end: usize,
/// Dimension of the full operator.
pub total_dim: usize,
}
impl HyperOperator for BlockLocalDrift {
    /// Dimension of the full operator, not of the block.
    fn dim(&self) -> usize {
        self.total_dim
    }

    /// Returns the product as a new vector of `v`'s length.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.mul_vec_into(v.view(), product.view_mut());
        product
    }

    /// View-based variant of `mul_vec`.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.mul_vec_into(v, product.view_mut());
        product
    }

    /// Zeroes `out`, then writes `local · v[start..end]` into `out[start..end]`.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        out.fill(0.0);
        let (lo, hi) = (self.start, self.end);
        let source = v.slice(ndarray::s![lo..hi]);
        let target = out.slice_mut(ndarray::s![lo..hi]);
        dense_matvec_into(&self.local, source, target);
    }

    /// Zeroes `out`, then copies the part of the block that overlaps the
    /// requested global columns `start..start + out.ncols()`.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        out.fill(0.0);
        let requested_end = start + out.ncols();
        // Intersection of the requested columns with the block's columns.
        let overlap_lo = self.start.max(start);
        let overlap_hi = self.end.min(requested_end);
        if overlap_lo >= overlap_hi {
            return;
        }
        let block_cols = (overlap_lo - self.start)..(overlap_hi - self.start);
        let out_cols = (overlap_lo - start)..(overlap_hi - start);
        let source = self.local.slice(ndarray::s![.., block_cols]);
        let mut target = out.slice_mut(ndarray::s![self.start..self.end, out_cols]);
        target.assign(&source);
    }

    /// Accumulates `scale * local · v[start..end]` into `out[start..end]`;
    /// does nothing when `scale` is exactly zero.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        let (lo, hi) = (self.start, self.end);
        let source = v.slice(ndarray::s![lo..hi]);
        let target = out.slice_mut(ndarray::s![lo..hi]);
        dense_matvec_scaled_add_into(&self.local, source, scale, target);
    }

    /// Bilinear form restricted to the block: `u[start..end] · (local · v[start..end])`.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let (lo, hi) = (self.start, self.end);
        let v_block = v.slice(ndarray::s![lo..hi]);
        let u_block = u.slice(ndarray::s![lo..hi]);
        let local_v = self.local.dot(&v_block);
        u_block.dot(&local_v)
    }

    /// Allocation-free bilinear form: accumulates row by row over the block.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let (lo, hi) = (self.start, self.end);
        let v_block = v.slice(ndarray::s![lo..hi]);
        let u_block = u.slice(ndarray::s![lo..hi]);
        self.local
            .rows()
            .into_iter()
            .zip(u_block.iter().copied())
            .fold(0.0_f64, |total, (row, u_value)| {
                let row_dot = row
                    .iter()
                    .copied()
                    .zip(v_block.iter().copied())
                    .fold(0.0_f64, |acc, (entry, v_value)| acc + entry * v_value);
                total + u_value * row_dot
            })
    }

    /// Embeds the block into a `total_dim x total_dim` zero matrix.
    fn to_dense(&self) -> Array2<f64> {
        let n = self.total_dim;
        let (lo, hi) = (self.start, self.end);
        let mut full = Array2::zeros((n, n));
        full.slice_mut(ndarray::s![lo..hi, lo..hi]).assign(&self.local);
        full
    }

    /// A block-local drift is never implicit.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Exposes `(local block, start, end)` to callers that special-case blocks.
    fn block_local_data(&self) -> Option<(&Array2<f64>, usize, usize)> {
        let block = &self.local;
        Some((block, self.start, self.end))
    }
}
/// Drift term stored as the sum of up to three optional parts: a dense
/// matrix, a block-local matrix, and a shared operator.
#[derive(Clone)]
pub struct HyperCoordDrift {
/// Optional full dense part.
pub dense: Option<Array2<f64>>,
/// Optional part confined to one diagonal block.
pub block_local: Option<BlockLocalDrift>,
/// Optional part that is applied as an operator rather than stored as a matrix.
pub operator: Option<Arc<dyn HyperOperator>>,
}
impl HyperCoordDrift {
    /// Drift with no parts at all.
    pub fn none() -> Self {
        Self {
            dense: None,
            block_local: None,
            operator: None,
        }
    }

    /// Drift consisting only of a dense matrix.
    pub fn from_dense(dense: Array2<f64>) -> Self {
        Self {
            dense: Some(dense),
            ..Self::none()
        }
    }

    /// Drift consisting only of an operator.
    pub fn from_operator(operator: Arc<dyn HyperOperator>) -> Self {
        Self {
            operator: Some(operator),
            ..Self::none()
        }
    }

    /// Drift built from optional dense and operator parts.
    ///
    /// An empty dense matrix is discarded when an operator is present, so it
    /// cannot shadow the operator's dimension.
    pub fn from_parts(
        dense: Option<Array2<f64>>,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let has_operator = operator.is_some();
        let dense = dense.filter(|mat| !has_operator || !mat.is_empty());
        Self {
            dense,
            block_local: None,
            operator,
        }
    }

    /// Drift built from a block-local matrix covering `start..end` of a
    /// `total_dim`-sized operator, plus an optional operator part.
    pub fn from_block_local_and_operator(
        local: Array2<f64>,
        start: usize,
        end: usize,
        total_dim: usize,
        operator: Option<Arc<dyn HyperOperator>>,
    ) -> Self {
        let block = BlockLocalDrift {
            local,
            start,
            end,
            total_dim,
        };
        Self {
            dense: None,
            block_local: Some(block),
            operator,
        }
    }

    /// `true` when an operator part is present.
    pub fn has_operator(&self) -> bool {
        self.operator.is_some()
    }

    /// `true` when either an operator part or a block-local part is present.
    pub fn uses_operator_fast_path(&self) -> bool {
        self.block_local.is_some() || self.operator.is_some()
    }

    /// Borrows the operator part, if any.
    pub fn operator_ref(&self) -> Option<&dyn HyperOperator> {
        self.operator.as_deref()
    }

    /// Materializes the sum of all parts as a dense square matrix.
    ///
    /// Returns a `0 x 0` matrix when the dimension is inferred to be zero.
    pub fn materialize(&self) -> Array2<f64> {
        let dim = self.infer_dim();
        if dim == 0 {
            return Array2::zeros((0, 0));
        }
        let mut full = match &self.dense {
            Some(dense) => dense.clone(),
            None => Array2::zeros((dim, dim)),
        };
        if let Some(block) = &self.block_local {
            let span = block.start..block.end;
            full.slice_mut(ndarray::s![span.clone(), span])
                .scaled_add(1.0, &block.local);
        }
        if let Some(operator) = &self.operator {
            full += &operator.to_dense();
        }
        full
    }

    /// Applies the drift to `v` and returns the product as a new vector.
    pub fn apply(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut product = Array1::zeros(v.len());
        self.scaled_add_apply(v.view(), 1.0, &mut product);
        product
    }

    /// Accumulates `scale * drift · v` into `out`, part by part (dense,
    /// block-local, then operator). Does nothing when `scale` is exactly zero.
    ///
    /// Panics if `v` and `out` differ in length.
    pub fn scaled_add_apply(&self, v: ArrayView1<'_, f64>, scale: f64, out: &mut Array1<f64>) {
        assert_eq!(v.len(), out.len());
        if scale == 0.0 {
            return;
        }
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        if let Some(block) = self.block_local.as_ref() {
            let (lo, hi) = (block.start, block.end);
            let source = v.slice(ndarray::s![lo..hi]);
            let target = out.slice_mut(ndarray::s![lo..hi]);
            dense_matvec_scaled_add_into(&block.local, source, scale, target);
        }
        if let Some(operator) = self.operator.as_ref() {
            operator.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Infers the drift's dimension from the first part present, in the order
    /// dense, operator, block-local; returns 0 when there are no parts.
    pub(crate) fn infer_dim(&self) -> usize {
        self.dense
            .as_ref()
            .map(|dense| dense.nrows())
            .or_else(|| self.operator.as_ref().map(|operator| operator.dim()))
            .or_else(|| self.block_local.as_ref().map(|block| block.total_dim))
            .unwrap_or(0)
    }
}
/// Per-thread scratch buffers that callers borrow through [`with`].
mod implicit_matvec_scratch {
    use std::cell::RefCell;

    /// Three growable `f64` buffers handed out together.
    pub(super) struct Scratch {
        pub x_v: Vec<f64>,
        pub n_work: Vec<f64>,
        pub p_work: Vec<f64>,
    }

    impl Scratch {
        /// Creates a scratch set whose buffers are all empty.
        pub(crate) const fn new() -> Self {
            Self {
                p_work: Vec::new(),
                n_work: Vec::new(),
                x_v: Vec::new(),
            }
        }
    }

    thread_local! {
        static SCRATCH: RefCell<Scratch> = const { RefCell::new(Scratch::new()) };
    }

    /// Runs `f` with mutable access to the calling thread's scratch buffers
    /// and returns whatever `f` returns.
    ///
    /// The buffers stay mutably borrowed until `f` returns, so calling `with`
    /// again from inside `f` on the same thread panics.
    pub(super) fn with<R>(f: impl FnOnce(&mut Scratch) -> R) -> R {
        SCRATCH.with(|cell| {
            let mut scratch = cell.borrow_mut();
            f(&mut *scratch)
        })
    }
}
/// Matrix-free hyper-operator whose design-derivative action comes from
/// `implicit_deriv` instead of a stored matrix.
///
/// As implemented by `matvec_with_shared_xz_into`, applying the operator to
/// `z` sums: `implicit_deriv.transpose_mul` of the weighted design product,
/// the transposed design applied to the weighted `implicit_deriv.forward_mul`
/// of `z`, `s_psi * z`, and, when `c_x_psi_beta` is set, the transposed
/// design applied to `c_x_psi_beta` times the design product.
pub struct ImplicitHyperOperator {
/// Provider of the implicit derivative products (`forward_mul`, `transpose_mul`, ...).
pub implicit_deriv: std::sync::Arc<crate::terms::basis::ImplicitDesignPsiDerivative>,
/// Axis index forwarded to every `implicit_deriv` call.
pub axis: usize,
/// Shared design matrix; its pointer also serves as cache identity in `cached_xf`.
pub(crate) x_design: std::sync::Arc<DesignMatrix>,
/// Per-observation weights; their length is used as the observation count.
pub(crate) w_diag: crate::matrix::SignedWeightsArc,
/// Dense term added as `s_psi * v` in matvecs and `v' s_psi u` in bilinear forms.
pub s_psi: Array2<f64>,
/// Operator dimension (length of every input and output coefficient vector).
pub(crate) p: usize,
/// Optional per-observation coefficients for the extra correction term.
pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
}
impl HyperOperator for ImplicitHyperOperator {
fn dim(&self) -> usize {
self.p
}
fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v.view(), out.view_mut());
out
}
fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(self.p);
self.mul_vec_into(v, out.view_mut());
out
}
fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
assert_eq!(v.len(), self.p);
let n_obs = self.w_diag.len();
implicit_matvec_scratch::with(|s| {
s.x_v.clear();
s.x_v.resize(n_obs, 0.0);
s.n_work.clear();
s.n_work.resize(n_obs, 0.0);
s.p_work.clear();
s.p_work.resize(self.p, 0.0);
let mut x_v_view = ndarray::ArrayViewMut1::from(s.x_v.as_mut_slice());
let n_work_view = ndarray::ArrayViewMut1::from(s.n_work.as_mut_slice());
let p_work_view = ndarray::ArrayViewMut1::from(s.p_work.as_mut_slice());
design_matrix_apply_view_into(&self.x_design, v, x_v_view.view_mut());
self.matvec_with_shared_xz_into(x_v_view.view(), v, out, n_work_view, p_work_view);
});
}
fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
let cols = out.ncols();
assert!(start + cols <= self.p);
let n_obs = self.w_diag.len();
let mut basis = Array1::<f64>::zeros(self.p);
let mut x_col = Array1::<f64>::zeros(n_obs);
let mut dx_col = Array1::<f64>::zeros(n_obs);
let mut weighted = Array1::<f64>::zeros(n_obs);
let mut term = Array1::<f64>::zeros(self.p);
for local_col in 0..cols {
let global_col = start + local_col;
let mut out_col = out.column_mut(local_col);
out_col.assign(&self.s_psi.column(global_col));
design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(x_col.view())
.par_for_each(|dst, &w, &x| *dst = w * x);
term.assign(
&self
.implicit_deriv
.transpose_mul(self.axis, &weighted.view())
.expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
);
out_col += &term;
basis[global_col] = 1.0;
dx_col.assign(
&self
.implicit_deriv
.forward_mul(self.axis, &basis.view())
.expect("radial scalar evaluation failed during implicit hyper forward_mul"),
);
basis[global_col] = 0.0;
Zip::from(weighted.view_mut())
.and(self.w_diag.view())
.and(dx_col.view())
.par_for_each(|dst, &w, &dx| *dst = w * dx);
design_matrix_transpose_apply_view_into(
&self.x_design,
weighted.view(),
term.view_mut(),
);
out_col += &term;
self.accumulate_c_correction_xt_into(
x_col.view(),
weighted.view_mut(),
term.view_mut(),
out_col,
);
}
}
fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
self.bilinear_view(v.view(), u.view())
}
fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
assert_eq!(v.len(), self.p);
assert_eq!(u.len(), self.p);
let x_v = design_matrix_apply_view(&self.x_design, v);
let x_u = design_matrix_apply_view(&self.x_design, u);
let dx_v = self
.implicit_deriv
.forward_mul(self.axis, &v)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let dx_u = self
.implicit_deriv
.forward_mul(self.axis, &u)
.expect("radial scalar evaluation failed during implicit hyper forward_mul");
let w = &*self.w_diag;
let mut design = 0.0;
for i in 0..w.len() {
design += dx_v[i] * w[i] * x_u[i];
design += dx_u[i] * w[i] * x_v[i];
}
design += self.c_correction_bilinear(&x_v, &x_u);
let penalty = dense_bilinear(&self.s_psi, v, u);
design + penalty
}
fn is_implicit(&self) -> bool {
true
}
fn as_implicit(&self) -> Option<&ImplicitHyperOperator> {
Some(self)
}
fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
assert_eq!(factor.nrows(), self.p);
let n_obs = self.w_diag.len();
let rank = factor.ncols();
if rank == 0 || n_obs == 0 {
return 0.0;
}
let xf = self.compute_xf(factor);
self.trace_projected_factor_with_xf(factor, xf.view())
}
fn trace_projected_factor_cached(
&self,
factor: &Array2<f64>,
cache: &ProjectedFactorCache,
) -> f64 {
assert_eq!(factor.nrows(), self.p);
let n_obs = self.w_diag.len();
let rank = factor.ncols();
if rank == 0 || n_obs == 0 {
return 0.0;
}
let xf = self.cached_xf(factor, cache);
self.trace_projected_factor_with_xf(factor, xf.view())
}
}
/// Picks how many rows to process per chunk for a row-major sweep.
///
/// The result is the number of `cols`-wide `f64` rows that fit in 8 MiB
/// (treating `cols == 0` as one column), raised to at least 512 rows and then
/// capped at `n_rows`. Consequently it is 0 exactly when `n_rows` is 0.
pub(crate) fn byte_balanced_row_chunk(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * std::cmp::max(cols, 1);
    let budget_rows = TARGET_BYTES / row_bytes;
    let preferred_rows = if budget_rows < MIN_CHUNK_ROWS {
        MIN_CHUNK_ROWS
    } else {
        budget_rows
    };
    std::cmp::min(preferred_rows, n_rows)
}
impl ImplicitHyperOperator {
/// Builds the `n_obs x rank` product of the design matrix with `factor`,
/// one row chunk at a time so only a bounded slab of design rows is live.
///
/// A failing row-chunk request is reported through `reml_contract_panic`.
pub(crate) fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
    let rank = factor.ncols();
    let n_obs = self.w_diag.len();
    let mut product = Array2::<f64>::zeros((n_obs, rank));
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
    let mut row_lo = 0usize;
    while row_lo < n_obs {
        let row_hi = n_obs.min(row_lo + chunk_rows);
        let design_rows = self
            .x_design
            .try_row_chunk(row_lo..row_hi)
            .unwrap_or_else(|err| {
                reml_contract_panic(format!(
                    "ImplicitHyperOperator::compute_xf row chunk failed: {err}"
                ))
            });
        let chunk_product = crate::faer_ndarray::fast_ab(&design_rows, factor);
        product
            .slice_mut(ndarray::s![row_lo..row_hi, ..])
            .assign(&chunk_product);
        row_lo = row_hi;
    }
    product
}
/// Returns the design product of `factor` from `cache`, computing it with
/// `compute_xf` on a miss.
///
/// The cache key combines the address of the shared design matrix with the
/// key derived from the factor's contents.
pub(crate) fn cached_xf(
    &self,
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Arc<Array2<f64>> {
    let design_ptr = Arc::as_ptr(&self.x_design);
    let key = ProjectedFactorKey::from_factor_view(design_ptr as usize, factor.view());
    cache.get_or_insert_with(key, || self.compute_xf(factor))
}
/// Trace of `factor' B factor` for this operator's axis, given the
/// precomputed design product `xf` of `factor`.
///
/// Returns `2 * design + correction + penalty`, where `design` sums
/// `(derivative product) * w * xf` over all entries, `correction` sums
/// `c * xf^2` when `c_x_psi_beta` is set, and `penalty` is the sum of the
/// elementwise product of `factor` with `s_psi * factor`.
///
/// # Panics
/// Panics when `xf` is not `n_obs x rank`, or when a derivative row chunk
/// cannot be evaluated.
pub(crate) fn trace_projected_factor_with_xf(
    &self,
    factor: &Array2<f64>,
    xf: ArrayView2<'_, f64>,
) -> f64 {
    let n_obs = self.w_diag.len();
    let rank = factor.ncols();
    assert_eq!(xf.dim(), (n_obs, rank));
    let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
    let weights = self.w_diag.as_ref();
    let curvature = self.c_x_psi_beta.as_deref();
    let mut design_total = 0.0_f64;
    let mut correction_total = 0.0_f64;
    let mut row_lo = 0usize;
    while row_lo < n_obs {
        let row_hi = (row_lo + chunk_rows).min(n_obs);
        let kd_chunk = self
            .implicit_deriv
            .row_chunk_first_raw(self.axis, row_lo..row_hi)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
        let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
        // Rows are visited in ascending order so the running sums accumulate
        // in the same sequence on every call.
        for (i_local, i) in (row_lo..row_hi).enumerate() {
            let w_i = weights[i];
            let dxf_row = dxf_chunk.row(i_local);
            let xf_row = xf.row(i);
            for (k, &xf_ik) in xf_row.iter().enumerate() {
                design_total += dxf_row[k] * w_i * xf_ik;
            }
            if let Some(c) = curvature {
                let c_i = c[i];
                for &xf_ik in xf_row.iter() {
                    correction_total += c_i * xf_ik * xf_ik;
                }
            }
        }
        row_lo = row_hi;
    }
    let s_factor = self.s_psi.dot(factor);
    let penalty: f64 = factor
        .iter()
        .zip(s_factor.iter())
        .map(|(&f_val, &s_val)| f_val * s_val)
        .sum();
    2.0 * design_total + correction_total + penalty
}
/// Computes the projected-factor trace for several axes in one sweep over
/// the observations, sharing `xf` and the unprojected factor across axes.
///
/// Each entry of `axes` is `(axis, s_psi, optional c)`; the returned vector
/// holds `2 * design + correction + penalty` for each entry, in order.
///
/// # Panics
/// Panics when `xf` is not `n_obs x rank`, or when a derivative row chunk
/// cannot be evaluated.
pub(crate) fn trace_projected_factor_all_axes_with_xf(
    &self,
    factor: &Array2<f64>,
    xf: ArrayView2<'_, f64>,
    axes: &[(usize, &Array2<f64>, Option<&Array1<f64>>)],
) -> Vec<f64> {
    let n_obs = self.w_diag.len();
    let rank = factor.ncols();
    assert_eq!(xf.dim(), (n_obs, rank));
    let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
    let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs.max(1));
    let weights = self.w_diag.as_ref();
    let n_axes = axes.len();
    let mut design_totals = vec![0.0_f64; n_axes];
    let mut correction_totals = vec![0.0_f64; n_axes];
    let mut row_lo = 0usize;
    while row_lo < n_obs {
        let row_hi = (row_lo + chunk_rows).min(n_obs);
        for (axis_idx, &(axis, _, c_axis)) in axes.iter().enumerate() {
            let kd_chunk = self
                .implicit_deriv
                .row_chunk_first_raw(axis, row_lo..row_hi)
                .expect(
                    "radial scalar evaluation failed during \
                     trace_projected_factor_all_axes_with_xf",
                );
            let dxf_chunk = crate::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
            let design_acc = &mut design_totals[axis_idx];
            let correction_acc = &mut correction_totals[axis_idx];
            for (i_local, i) in (row_lo..row_hi).enumerate() {
                let w_i = weights[i];
                let dxf_row = dxf_chunk.row(i_local);
                let xf_row = xf.row(i);
                for (k, &xf_ik) in xf_row.iter().enumerate() {
                    *design_acc += dxf_row[k] * w_i * xf_ik;
                }
                if let Some(c) = c_axis {
                    let c_i = c[i];
                    for &xf_ik in xf_row.iter() {
                        *correction_acc += c_i * xf_ik * xf_ik;
                    }
                }
            }
        }
        row_lo = row_hi;
    }
    axes.iter()
        .zip(design_totals.iter().zip(correction_totals.iter()))
        .map(|(&(_, s_psi, _), (&design, &correction))| {
            let s_factor = s_psi.dot(factor);
            let penalty: f64 = factor
                .iter()
                .zip(s_factor.iter())
                .map(|(&f_val, &s_val)| f_val * s_val)
                .sum();
            2.0 * design + correction + penalty
        })
        .collect()
}
/// Adds the optional correction term for one design column to `out_col`.
///
/// When `c_x_psi_beta` is absent this is a no-op. Otherwise `n_work` is
/// overwritten with `c * x_col`, `p_work` with the transposed design applied
/// to it, and `p_work` is then added to `out_col`.
///
/// # Panics
/// Panics when `x_col` or `n_work` differ in length from `c_x_psi_beta`, or
/// when `p_work.len() != p`.
pub(crate) fn accumulate_c_correction_xt_into(
    &self,
    x_col: ArrayView1<'_, f64>,
    mut n_work: ArrayViewMut1<'_, f64>,
    mut p_work: ArrayViewMut1<'_, f64>,
    mut out_col: ArrayViewMut1<'_, f64>,
) {
    let Some(c) = self.c_x_psi_beta.as_deref() else {
        return;
    };
    assert_eq!(x_col.len(), c.len());
    assert_eq!(n_work.len(), c.len());
    assert_eq!(p_work.len(), self.p);
    // All three lengths were just checked equal, so zipping covers every entry.
    for ((slot, &c_i), &x_i) in n_work.iter_mut().zip(c.iter()).zip(x_col.iter()) {
        *slot = c_i * x_i;
    }
    design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
    out_col += &p_work;
}
/// Correction part of the bilinear form: the sum of `x_v * c * x_u` over
/// paired entries, or `0.0` when `c_x_psi_beta` is absent.
///
/// The sum stops at the shortest of the three sequences.
pub(crate) fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
    match self.c_x_psi_beta.as_deref() {
        None => 0.0,
        Some(c) => x_v
            .iter()
            .zip(x_u.iter())
            .zip(c.iter())
            .map(|((&xv_i, &xu_i), &c_i)| xv_i * c_i * xu_i)
            .sum::<f64>(),
    }
}
/// Bilinear form of the operator on `z` and `u` when the caller already
/// holds the per-observation vectors `x_vec` and `y_vec`.
///
/// `x_vec` is paired with the derivative product of `u`, and `y_vec` with
/// the derivative product of `z`; comparing with `bilinear_view`, they play
/// the roles of the design products of `z` and `u` respectively. The
/// optional correction and the `s_psi` term are added afterwards.
///
/// # Panics
/// Panics when an `implicit_deriv` product fails, or when any
/// per-observation vector is shorter than `x_vec`.
pub fn bilinear_with_shared_x(
    &self,
    x_vec: &Array1<f64>,
    y_vec: &Array1<f64>,
    z: &Array1<f64>,
    u: &Array1<f64>,
) -> f64 {
    let deriv_z = self
        .implicit_deriv
        .forward_mul(self.axis, &z.view())
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");
    let deriv_u = self
        .implicit_deriv
        .forward_mul(self.axis, &u.view())
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");
    let weights = &*self.w_diag;
    let n = x_vec.len();
    let mut design = 0.0f64;
    for i in 0..n {
        let w_i = weights[i];
        design += deriv_z[i] * w_i * y_vec[i];
        design += deriv_u[i] * w_i * x_vec[i];
    }
    if let Some(c) = self.c_x_psi_beta.as_deref() {
        for i in 0..n {
            design += y_vec[i] * c[i] * x_vec[i];
        }
    }
    let penalty = dense_bilinear(&self.s_psi, z.view(), u.view());
    design + penalty
}
/// Writes the operator applied to `z` into `out`, given the caller's
/// per-observation vector `x_vec` (the design product of `z` in
/// `mul_vec_into`).
///
/// `out` is overwritten with the sum of four pieces, in this order: the
/// transposed derivative product of `w * x_vec`; the transposed design
/// applied to `w * (derivative product of z)`; `s_psi * z`; and, when
/// `c_x_psi_beta` is set, the transposed design applied to `c * x_vec`.
/// `n_work` and `p_work` are scratch and hold unspecified values afterwards.
///
/// # Panics
/// Panics when `z`, `out` or `p_work` are not of length `p`, when `n_work`
/// does not match the weight count, or when an `implicit_deriv` product fails.
pub fn matvec_with_shared_xz_into(
    &self,
    x_vec: ArrayView1<'_, f64>,
    z: ArrayView1<'_, f64>,
    mut out: ArrayViewMut1<'_, f64>,
    mut n_work: ArrayViewMut1<'_, f64>,
    mut p_work: ArrayViewMut1<'_, f64>,
) {
    assert_eq!(z.len(), self.p);
    assert_eq!(out.len(), self.p);
    assert_eq!(n_work.len(), self.w_diag.len());
    assert_eq!(p_work.len(), self.p);
    let weights = &*self.w_diag;

    // Piece 1: transposed derivative product of the weighted `x_vec`.
    for (i, slot) in n_work.iter_mut().enumerate() {
        *slot = weights[i] * x_vec[i];
    }
    let deriv_t_term = self
        .implicit_deriv
        .transpose_mul(self.axis, &n_work.view())
        .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
    out.assign(&deriv_t_term);

    // Piece 2: transposed design applied to the weighted derivative product of `z`.
    let deriv_z = self
        .implicit_deriv
        .forward_mul(self.axis, &z)
        .expect("radial scalar evaluation failed during implicit hyper forward_mul");
    for (i, slot) in n_work.iter_mut().enumerate() {
        *slot = weights[i] * deriv_z[i];
    }
    design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
    out += &p_work;

    // Piece 3: dense `s_psi` term.
    dense_matvec_into(&self.s_psi, z, p_work.view_mut());
    out += &p_work;

    // Piece 4: optional correction term.
    if let Some(c) = self.c_x_psi_beta.as_deref() {
        for i in 0..weights.len() {
            n_work[i] = c[i] * x_vec[i];
        }
        design_matrix_transpose_apply_view_into(
            &self.x_design,
            n_work.view(),
            p_work.view_mut(),
        );
        out += &p_work;
    }
}
}
/// Directional hyper-operator that keeps the design derivative in operator
/// form (`x_tau`) and is applied through matrix-vector products only.
///
/// `mul_vec` sums the transposed `x_tau` product of the weighted design
/// product, the transposed design product of the weighted `x_tau` product,
/// and `s_tau * v`; it then adds the optional `c_x_tau_beta` term and
/// subtracts the optional `firth_hphi_tau_partial * v`.
pub struct SparseDirectionalHyperOperator {
/// Design derivative, used via `forward_mul_original` / `transpose_mul_original`.
pub(crate) x_tau: super::super::HyperDesignDerivative,
/// Design matrix.
pub(crate) x_design: DesignMatrix,
/// Per-observation weights multiplied elementwise into design products.
pub(crate) w_diag: crate::matrix::SignedWeightsArc,
/// Dense term added as `s_tau * v`.
pub(crate) s_tau: Array2<f64>,
/// Optional per-observation coefficients for the extra correction term.
pub(crate) c_x_tau_beta: Option<Array1<f64>>,
/// Optional dense matrix whose product with `v` is subtracted from the result.
pub(crate) firth_hphi_tau_partial: Option<Array2<f64>>,
/// Operator dimension; `mul_vec` asserts its input has this length.
pub(crate) p: usize,
}
impl HyperOperator for SparseDirectionalHyperOperator {
    /// Operator dimension `p`.
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the operator to `v`.
    ///
    /// The result is the sum of the two weighted cross terms between the
    /// design and its derivative and `s_tau * v`, plus the optional
    /// `c_x_tau_beta` correction, minus the optional
    /// `firth_hphi_tau_partial * v`.
    ///
    /// # Panics
    /// Panics when `v.len() != p` or when an `x_tau` product reports an error.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let weights = &*self.w_diag;

        let design_v = self.x_design.matrixvectormultiply(v);
        let weighted_design_v = weights * &design_v;
        let deriv_t_term = self
            .x_tau
            .transpose_mul_original(&weighted_design_v)
            .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");

        let deriv_v = self
            .x_tau
            .forward_mul_original(v)
            .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
        let weighted_deriv_v = weights * &deriv_v;
        let design_t_term = self.x_design.transpose_vector_multiply(&weighted_deriv_v);

        let penalty_term = self.s_tau.dot(v);
        let mut out = deriv_t_term + design_t_term + penalty_term;

        if let Some(c) = &self.c_x_tau_beta {
            let c_design_v = c * &design_v;
            out += &self.x_design.transpose_vector_multiply(&c_design_v);
        }
        if let Some(hphi) = &self.firth_hphi_tau_partial {
            out -= &hphi.dot(v);
        }
        out
    }

    /// This operator is not reported as implicit.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Downcast hook: exposes the concrete sparse directional operator.
    fn as_sparse_directional(&self) -> Option<&SparseDirectionalHyperOperator> {
        Some(self)
    }
}
/// Operator that sandwiches a per-observation coefficient vector between a
/// design product and a transposed design product.
///
/// `mul_vec` computes the design product of `v`, scales it elementwise by
/// `neg_c_xv`, and applies the transposed design to the result.
pub struct GlmCurvatureCorrectionOperator {
/// Design matrix used for both the forward and the transposed product.
pub(crate) x_design: DesignMatrix,
/// Per-observation coefficients multiplied into the design product.
/// NOTE(review): the name suggests the sign is already folded in — confirm at the construction site.
pub(crate) neg_c_xv: Array1<f64>,
/// Operator dimension; `mul_vec` asserts its input has this length.
pub(crate) p: usize,
}
impl HyperOperator for GlmCurvatureCorrectionOperator {
    /// Operator dimension `p`.
    fn dim(&self) -> usize {
        self.p
    }

    /// Applies the operator: design product of `v`, scaled elementwise by
    /// `neg_c_xv`, then mapped back through the transposed design.
    ///
    /// # Panics
    /// Panics when `v.len() != p`.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        assert_eq!(v.len(), self.p);
        let design_v = self.x_design.matrixvectormultiply(v);
        let scaled = &self.neg_c_xv * &design_v;
        self.x_design.transpose_vector_multiply(&scaled)
    }

    /// This operator is not reported as implicit.
    fn is_implicit(&self) -> bool {
        false
    }
}