use super::*;
/// Seeds one log smoothing parameter per penalty from relative matrix scales.
///
/// Each seed is `ln(likelihood_scale / penalty_scale)`, where
/// - `likelihood_scale` is the mean absolute diagonal of `likelihood_gram`;
/// - `penalty_scale` is [`penalty_diag_scale`] of that penalty;
/// - both scales are floored at `CTN_SEED_SCALE_FLOOR`.
///
/// The result is clamped to `[CTN_SEED_LOG_LAMBDA_MIN, CTN_SEED_LOG_LAMBDA_MAX]`.
/// An empty `penalties` slice yields an empty array.
pub(crate) fn ctn_penalty_scale_log_lambdas(
    penalties: &[PenaltyMatrix],
    likelihood_gram: &Array2<f64>,
) -> Array1<f64> {
    if penalties.is_empty() {
        return Array1::zeros(0);
    }
    let floored_likelihood =
        matrix_diag_mean_abs(likelihood_gram).max(CTN_SEED_SCALE_FLOOR);
    penalties
        .iter()
        .map(|p| {
            let floored_penalty = penalty_diag_scale(p).max(CTN_SEED_SCALE_FLOOR);
            let log_ratio = (floored_likelihood / floored_penalty).ln();
            log_ratio.clamp(CTN_SEED_LOG_LAMBDA_MIN, CTN_SEED_LOG_LAMBDA_MAX)
        })
        .collect::<Array1<f64>>()
}
pub(crate) fn penalty_diag_scale(penalty: &PenaltyMatrix) -> f64 {
match penalty {
PenaltyMatrix::Dense(matrix) => {
matrix_diag_mean_abs(matrix).max(matrix_frobenius_rms(matrix))
}
PenaltyMatrix::KroneckerFactored { left, right } => {
let diag_scale = matrix_diag_mean_abs(left) * matrix_diag_mean_abs(right);
let frob_scale = matrix_frobenius_rms(left) * matrix_frobenius_rms(right);
diag_scale.max(frob_scale)
}
PenaltyMatrix::Blockwise { local, .. } => {
matrix_diag_mean_abs(local).max(matrix_frobenius_rms(local))
}
PenaltyMatrix::Labeled { inner, .. } => penalty_diag_scale(inner),
PenaltyMatrix::Fixed { inner, .. } => penalty_diag_scale(inner),
}
}
/// Mean absolute value of the main diagonal.
///
/// The diagonal has `min(nrows, ncols)` entries. A matrix with no diagonal
/// entries yields `0.0`.
pub(crate) fn matrix_diag_mean_abs(matrix: &Array2<f64>) -> f64 {
    let diagonal = matrix.diag();
    match diagonal.len() {
        0 => 0.0,
        len => {
            let abs_total = diagonal.iter().fold(0.0_f64, |acc, v| acc + v.abs());
            abs_total / len as f64
        }
    }
}
/// Frobenius norm divided by `sqrt(min(nrows, ncols))`.
///
/// The denominator is floored at 1 so empty matrices yield `0.0`. For an identity
/// matrix the result is `1.0`, which puts it on the same footing as
/// `matrix_diag_mean_abs`.
pub(crate) fn matrix_frobenius_rms(matrix: &Array2<f64>) -> f64 {
    let denom = matrix.nrows().min(matrix.ncols()).max(1) as f64;
    let sum_sq: f64 = matrix.iter().map(|v| v * v).sum();
    (sum_sq / denom).sqrt()
}
/// Weighted cross-product of two row-wise Kronecker (face-splitting) products.
///
/// Returns a `(pa * pb) x (pc * pd)` matrix. The entry at
/// `[ia * pb + ib, ic * pd + id]` equals
/// `sum_r weights[r] * a[r, ia] * b[r, ib] * c[r, ic] * d[r, id]`.
///
/// The two Kronecker-expanded designs are never materialized. Instead, each
/// `(ia, ic)` block is formed as `bᵀ diag(weights * a[:, ia] * c[:, ic]) d` via
/// `chunked_weighted_bt_d`. Row bands (one per `ia`) are filled in parallel.
///
/// # Errors
/// Returns `Err` if any of `a`, `b`, `c`, `d` has a row count different from
/// `weights.len()`.
pub(crate) fn factored_weighted_cross(
    a: &Array2<f64>,
    b: &Array2<f64>,
    weights: ndarray::ArrayView1<'_, f64>,
    c: &Array2<f64>,
    d: &Array2<f64>,
    policy: &ResourcePolicy,
) -> Result<Array2<f64>, String> {
    use gam_problem::with_nested_parallel;
    use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

    let n = weights.len();
    let row_counts = [a.nrows(), b.nrows(), c.nrows(), d.nrows()];
    if row_counts.iter().any(|&rows| rows != n) {
        return Err(TransformationNormalError::InvalidInput {
            reason: format!(
                "factored_weighted_cross row mismatch: weights={n}, a={}, b={}, c={}, d={}",
                row_counts[0], row_counts[1], row_counts[2], row_counts[3]
            ),
        }
        .into());
    }

    let (pa, pb) = (a.ncols(), b.ncols());
    let (pc, pd) = (c.ncols(), d.ncols());
    let mut result = Array2::<f64>::zeros((pa * pb, pc * pd));
    // One band of `pb` rows per column of `a`. The `max(1)` only keeps the chunk
    // size legal when `pb == 0`, in which case `result` has no rows anyway.
    let band_height = pb.max(1);

    result
        .axis_chunks_iter_mut(ndarray::Axis(0), band_height)
        .into_par_iter()
        .enumerate()
        .for_each(|(ia, mut band)| {
            with_nested_parallel(|| {
                let a_col = a.column(ia);
                // Reused across `ic` to avoid one allocation per block.
                let mut combined = Array1::<f64>::zeros(n);
                for ic in 0..pc {
                    let c_col = c.column(ic);
                    ndarray::Zip::from(&mut combined)
                        .and(&weights)
                        .and(&a_col)
                        .and(&c_col)
                        .for_each(|slot, &w, &av, &cv| *slot = w * av * cv);
                    let cross_block = chunked_weighted_bt_d(b, combined.view(), d, policy);
                    let col_range = ic * pd..(ic + 1) * pd;
                    band.slice_mut(s![.., col_range]).assign(&cross_block);
                }
            });
        });
    Ok(result)
}
/// Computes `bᵀ · diag(weights) · d` (shape `b.ncols() x d.ncols()`) by streaming over row chunks.
///
/// For each chunk of rows, `d` is scaled row-by-row into a reusable buffer. The
/// product `b_chunkᵀ · (w ⊙ d_chunk)` is then accumulated into the output with a
/// faer `matmul` using `Accum::Add`. Chunk height comes from
/// `gam_runtime::resource::rows_for_target_bytes`, driven by
/// `policy.row_chunk_target_bytes`.
///
/// `b` and `d` are expected to have `weights.len()` rows. The row slicing below
/// panics if either has fewer. An all-zero matrix is returned when there are no
/// rows or either output dimension is zero.
pub(crate) fn chunked_weighted_bt_d(
    b: &Array2<f64>,
    weights: ndarray::ArrayView1<'_, f64>,
    d: &Array2<f64>,
    policy: &ResourcePolicy,
) -> Array2<f64> {
    use gam_linalg::faer_ndarray::{FaerArrayView, array2_to_matmut, matmul_parallelism};
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let n = weights.len();
    let pb = b.ncols();
    let pd = d.ncols();
    let mut out = Array2::<f64>::zeros((pb, pd));
    if n == 0 || pb == 0 || pd == 0 {
        return out;
    }
    // Sized only after the degenerate-shape return, so the helper never sees a zero
    // column width. Clamped to >= 1 because `step_by(0)` panics.
    let rows_per_chunk =
        gam_runtime::resource::rows_for_target_bytes(policy.row_chunk_target_bytes, pb + pd)
            .max(1);
    let mut out_view = array2_to_matmut(&mut out);
    // Scratch for the weighted rows of `d`. It is allocated once and only its first
    // `rows` rows are used by the (possibly shorter) final chunk.
    let mut dw_buf = Array2::<f64>::zeros((rows_per_chunk.min(n), pd));
    for start in (0..n).step_by(rows_per_chunk) {
        let end = (start + rows_per_chunk).min(n);
        let rows = end - start;
        let bl = b.slice(s![start..end, ..]);
        let dl = d.slice(s![start..end, ..]);
        {
            // Scoped so the mutable borrow of `dw_buf` ends before the read below.
            let mut dw_slice = dw_buf.slice_mut(s![..rows, ..]);
            for local in 0..rows {
                let w = weights[start + local];
                let drow = dl.row(local);
                let mut wrow = dw_slice.row_mut(local);
                ndarray::Zip::from(&mut wrow)
                    .and(&drow)
                    .for_each(|dst, &src| *dst = w * src);
            }
        }
        let bl_view = FaerArrayView::new(&bl);
        let dw_slice = dw_buf.slice(s![..rows, ..]);
        let dw_view = FaerArrayView::new(&dw_slice);
        let par = matmul_parallelism(pb, pd, rows);
        // out += bl^T * dw
        matmul(
            out_view.as_mut(),
            Accum::Add,
            bl_view.as_ref().transpose(),
            dw_view.as_ref(),
            1.0,
            par,
        );
    }
    out
}
/// `DesignMatrix` variant of `chunked_weighted_bt_d`.
///
/// Computes `bᵀ · diag(weights) · d` (shape `b.ncols() x d.ncols()`) chunk by chunk.
/// Each row chunk of `b` and `d` is fetched with `try_row_chunk`. The `d` chunk is
/// scaled in place by the matching weights, skipping rows whose weight is exactly
/// `1.0`. `b_chunkᵀ · d_chunk` is then accumulated into the output with a faer
/// `matmul` using `Accum::Add`.
///
/// NOTE(review): assumes `try_row_chunk` returns an owned dense chunk with
/// `end - start` rows (the code mutates its rows); confirm against `DesignMatrix`.
///
/// # Errors
/// Returns `Err` with the stringified error if fetching any row chunk fails.
pub(crate) fn chunked_weighted_bt_d_designmatrix(
    b: &DesignMatrix,
    weights: ndarray::ArrayView1<'_, f64>,
    d: &DesignMatrix,
    policy: &ResourcePolicy,
) -> Result<Array2<f64>, String> {
    use gam_linalg::faer_ndarray::{FaerArrayView, array2_to_matmut, matmul_parallelism};
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let n = weights.len();
    let pb = b.ncols();
    let pd = d.ncols();
    let mut out = Array2::<f64>::zeros((pb, pd));
    if n == 0 || pb == 0 || pd == 0 {
        return Ok(out);
    }
    // Sized only after the degenerate-shape return, so the helper never sees a zero
    // column width. Clamped to >= 1 because `step_by(0)` panics.
    let rows_per_chunk =
        gam_runtime::resource::rows_for_target_bytes(policy.row_chunk_target_bytes, pb + pd)
            .max(1);
    let mut out_view = array2_to_matmut(&mut out);
    for start in (0..n).step_by(rows_per_chunk) {
        let end = (start + rows_per_chunk).min(n);
        let rows = end - start;
        let bl = b.try_row_chunk(start..end).map_err(|e| e.to_string())?;
        let mut dw = d.try_row_chunk(start..end).map_err(|e| e.to_string())?;
        for local in 0..rows {
            let w = weights[start + local];
            // Unit weights leave the row unchanged, so skip the multiply.
            if w != 1.0 {
                let mut wrow = dw.row_mut(local);
                wrow.mapv_inplace(|v| w * v);
            }
        }
        let bl_view = FaerArrayView::new(&bl);
        let dw_view = FaerArrayView::new(&dw);
        let par = matmul_parallelism(pb, pd, rows);
        // out += bl^T * dw
        matmul(
            out_view.as_mut(),
            Accum::Add,
            bl_view.as_ref().transpose(),
            dw_view.as_ref(),
            1.0,
            par,
        );
    }
    Ok(out)
}