1use super::reml_outer_engine::{
11 BarrierConfig, ContractedPsiSecondOrderFn, DispersionHandling, EvalMode, FixedDriftDerivFn,
12 HessianDerivativeProvider, HessianOperator, HyperCoord, HyperCoordPair, InnerSolution,
13 InnerSolutionBuilder, PenaltyCoordinate, PenaltyLogdetDerivs, PenaltySubspaceTrace,
14 RemlLamlResult, penalty_matrix_root, reml_laml_evaluate,
15};
16use gam_linalg::faer_ndarray::fast_xt_diag_y;
17use crate::model_types::ProjectedKktResidual;
18use ndarray::{Array1, Array2};
19use rayon::iter::{
20 IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
21};
22use rayon::slice::ParallelSliceMut;
23use std::sync::Arc;
24
/// Work threshold for the dense weighted products in this file.
///
/// `weighted_cross_dense` compares `n * p * q` and `xt_diag_x_dense_into` compares
/// `n * p * p` against this value; below it (or with a single rayon thread) they use the
/// single-call kernel instead of the chunked rayon reduction. The same value is also
/// passed as the last argument to `parallel_strategy::row_reduction_chunk_rows`.
pub(crate) const DENSE_WEIGHTED_PRODUCT_PAR_FLOPS: usize = 8_000_000;
/// Minimum number of matrix cells (`rows * cols`) before `row_scale_dense_in_place`
/// considers scaling rows in parallel.
pub(crate) const DENSE_ROW_SCALE_PAR_CELLS: usize = 64 * 1024;
35
/// Selects how a per-row scale value is turned into the factor applied to that row by
/// `row_scale_dense_in_place` / `scale_dense_row_values`.
#[derive(Clone, Copy)]
pub(crate) enum DenseRowScaleMode {
    /// Multiply the row by the scale value as given.
    Direct,
    /// Multiply the row by `1 / scale` when `scale > 0.0`; otherwise set the row to zero.
    InversePositiveOrZero,
}
41
/// Picks how many rows of a dense `f64` matrix with `cols` columns to handle per chunk.
///
/// The row count targets about 2 MiB of row data per chunk and is always clamped to
/// `256..=4096`. `cols == 0` is treated as a single column so the division is defined.
///
/// The bytes-per-row product saturates instead of overflowing: for column counts so
/// large that `cols * size_of::<f64>()` does not fit in `usize`, the function returns
/// the minimum chunk size rather than panicking (debug) or wrapping — possibly to a
/// zero divisor — in release builds.
#[inline]
pub(crate) fn dense_weighted_chunk_rows(cols: usize) -> usize {
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    const MIN_ROWS: usize = 256;
    const MAX_ROWS: usize = 4096;
    // Saturating keeps the divisor non-zero and monotone in `cols` for any input.
    let bytes_per_row = cols.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row).clamp(MIN_ROWS, MAX_ROWS)
}
50
51pub(crate) fn row_scale_dense_into(x: &Array2<f64>, scale: &Array1<f64>, out: &mut Array2<f64>) {
58 assert_eq!(x.nrows(), scale.len(), "scale length must match row count");
59 if out.raw_dim() != x.raw_dim() {
60 *out = Array2::<f64>::zeros(x.raw_dim());
61 }
62 out.assign(x);
63 row_scale_dense_in_place(out, scale, DenseRowScaleMode::Direct);
64}
65
/// Divides each row `i` of `out` in place by `scale[i]` when `scale[i] > 0.0`, and zeroes
/// the row otherwise.
///
/// Convenience wrapper around `row_scale_dense_in_place` with
/// `DenseRowScaleMode::InversePositiveOrZero`; it panics under the same condition
/// (`scale.len() != out.nrows()`).
pub(crate) fn row_scale_dense_in_place_by_inverse_positive_or_zero(
    out: &mut Array2<f64>,
    scale: &Array1<f64>,
) {
    row_scale_dense_in_place(out, scale, DenseRowScaleMode::InversePositiveOrZero);
}
74
75pub(crate) fn row_scale_dense_in_place(
76 out: &mut Array2<f64>,
77 scale: &Array1<f64>,
78 mode: DenseRowScaleMode,
79) {
80 assert_eq!(
81 out.nrows(),
82 scale.len(),
83 "scale length must match row count"
84 );
85 let ncols = out.ncols();
86 if ncols == 0 {
87 return;
88 }
89
90 let cells = out.nrows().saturating_mul(ncols);
91 if cells >= DENSE_ROW_SCALE_PAR_CELLS
92 && rayon::current_num_threads() > 1
93 && out.is_standard_layout()
94 && let Some(slice) = out.as_slice_memory_order_mut()
95 {
96 slice
97 .par_chunks_mut(ncols)
98 .zip(
99 scale
100 .as_slice()
101 .expect("Array1 must be contiguous")
102 .par_iter(),
103 )
104 .for_each(|(row_values, &w)| scale_dense_row_values(row_values, w, mode));
105 return;
106 }
107
108 ndarray::Zip::from(out.rows_mut())
109 .and(scale.view())
110 .for_each(|mut row, &w| {
111 if let Some(row_values) = row.as_slice_mut() {
112 scale_dense_row_values(row_values, w, mode);
113 } else {
114 match mode {
115 DenseRowScaleMode::Direct => row *= w,
116 DenseRowScaleMode::InversePositiveOrZero => {
117 if w > 0.0 {
118 row *= w.recip();
119 } else {
120 row.fill(0.0);
121 }
122 }
123 }
124 }
125 });
126}
127
128#[inline]
129pub(crate) fn scale_dense_row_values(row_values: &mut [f64], scale: f64, mode: DenseRowScaleMode) {
130 match mode {
131 DenseRowScaleMode::Direct => {
132 for value in row_values {
133 *value *= scale;
134 }
135 }
136 DenseRowScaleMode::InversePositiveOrZero => {
137 if scale > 0.0 {
138 let inv = scale.recip();
139 for value in row_values {
140 *value *= inv;
141 }
142 } else {
143 for value in row_values {
144 *value = 0.0;
145 }
146 }
147 }
148 }
149}
150
151pub(crate) fn accumulate_weighted_cross_rows(
152 out: &mut Array2<f64>,
153 left: &Array2<f64>,
154 right: &Array2<f64>,
155 weights: &Array1<f64>,
156 row_start: usize,
157 row_end: usize,
158) {
159 let p = left.ncols();
160 let q = right.ncols();
161 for i in row_start..row_end {
162 let wi = weights[i];
163 if wi == 0.0 {
164 continue;
165 }
166 for a in 0..p {
167 let scaled = wi * left[[i, a]];
168 if scaled == 0.0 {
169 continue;
170 }
171 for b in 0..q {
172 out[[a, b]] += scaled * right[[i, b]];
173 }
174 }
175 }
176}
177
178pub(crate) fn accumulate_xt_diag_x_upper_rows(
179 out: &mut Array2<f64>,
180 x: &Array2<f64>,
181 diag: &Array1<f64>,
182 row_start: usize,
183 row_end: usize,
184) {
185 let p = x.ncols();
186 for i in row_start..row_end {
187 let wi = diag[i];
188 if wi == 0.0 {
189 continue;
190 }
191 for a in 0..p {
192 let scaled = wi * x[[i, a]];
193 if scaled == 0.0 {
194 continue;
195 }
196 for b in a..p {
197 out[[a, b]] += scaled * x[[i, b]];
198 }
199 }
200 }
201}
202
203pub(crate) fn weighted_cross_dense(
208 left: &Array2<f64>,
209 right: &Array2<f64>,
210 weights: &Array1<f64>,
211) -> Array2<f64> {
212 assert_eq!(left.nrows(), right.nrows());
213 assert_eq!(left.nrows(), weights.len());
214 let n = weights.len();
215 let p = left.ncols();
216 let q = right.ncols();
217 if n == 0 || p == 0 || q == 0 {
218 return Array2::<f64>::zeros((p, q));
219 }
220
221 let work = n.saturating_mul(p).saturating_mul(q);
222 if rayon::current_num_threads() <= 1 || work < DENSE_WEIGHTED_PRODUCT_PAR_FLOPS {
223 return fast_xt_diag_y(left, weights, right);
224 }
225
226 let chunk_rows = crate::parallel_strategy::row_reduction_chunk_rows(
227 n,
228 p.saturating_mul(q),
229 p.saturating_mul(q),
230 DENSE_WEIGHTED_PRODUCT_PAR_FLOPS,
231 )
232 .unwrap_or_else(|| dense_weighted_chunk_rows(p + q).min(n));
233 let chunks = n.div_ceil(chunk_rows);
234 (0..chunks)
235 .into_par_iter()
236 .fold(
237 || Array2::<f64>::zeros((p, q)),
238 |mut local, chunk| {
239 let start = chunk * chunk_rows;
240 let end = (start + chunk_rows).min(n);
241 accumulate_weighted_cross_rows(&mut local, left, right, weights, start, end);
242 local
243 },
244 )
245 .reduce(
246 || Array2::<f64>::zeros((p, q)),
247 |mut a, b| {
248 a += &b;
249 a
250 },
251 )
252}
253
254pub(crate) fn xt_diag_x_dense_into(
259 x: &Array2<f64>,
260 diag: &Array1<f64>,
261 weighted: &mut Array2<f64>,
262) -> Array2<f64> {
263 let (n, p) = x.dim();
264 assert_eq!(diag.len(), n, "diag length must match row count");
265 if n == 0 || p == 0 {
266 return Array2::<f64>::zeros((p, p));
267 }
268
269 let work = n.saturating_mul(p).saturating_mul(p);
270 if rayon::current_num_threads() <= 1 || work < DENSE_WEIGHTED_PRODUCT_PAR_FLOPS {
271 row_scale_dense_into(x, diag, weighted);
272 return gam_linalg::faer_ndarray::fast_atb(x, weighted);
273 }
274
275 let chunk_rows = crate::parallel_strategy::row_reduction_chunk_rows(
276 n,
277 p.saturating_mul(p),
278 p.saturating_mul(p),
279 DENSE_WEIGHTED_PRODUCT_PAR_FLOPS,
280 )
281 .unwrap_or_else(|| dense_weighted_chunk_rows(p).min(n));
282 let chunks = n.div_ceil(chunk_rows);
283 let mut out = (0..chunks)
284 .into_par_iter()
285 .fold(
286 || Array2::<f64>::zeros((p, p)),
287 |mut local, chunk| {
288 let start = chunk * chunk_rows;
289 let end = (start + chunk_rows).min(n);
290 accumulate_xt_diag_x_upper_rows(&mut local, x, diag, start, end);
291 local
292 },
293 )
294 .reduce(
295 || Array2::<f64>::zeros((p, p)),
296 |mut a, b| {
297 a += &b;
298 a
299 },
300 );
301 for a in 0..p {
302 for b in 0..a {
303 out[[a, b]] = out[[b, a]];
304 }
305 }
306 out
307}
308
/// Plain-field bundle of everything `build` feeds into `InnerSolutionBuilder` to produce
/// an [`InnerSolution`].
///
/// The first eight fields are the builder's constructor arguments; the remaining fields
/// are applied through builder setters, some unconditionally and some only when present
/// (see the per-field notes). `'dp` is the lifetime bound of the boxed derivative
/// provider and of the resulting `InnerSolution`.
pub struct InnerAssembly<'dp> {
    // --- Constructor arguments of `InnerSolutionBuilder::new`, in this order. ---
    pub log_likelihood: f64,
    pub penalty_quadratic: f64,
    pub beta: Array1<f64>,
    pub n_observations: usize,
    pub hessian_op: std::sync::Arc<dyn HessianOperator>,
    pub penalty_coords: Vec<PenaltyCoordinate>,
    pub penalty_logdet: PenaltyLogdetDerivs,
    pub dispersion: DispersionHandling,
    // --- Always forwarded to the matching builder setter. ---
    pub rho_curvature_scale: f64,
    pub rho_prior: gam_problem::RhoPrior,
    pub hessian_logdet_correction: f64,
    pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,

    /// Forwarded to `deriv_provider` only when `Some`.
    pub deriv_provider: Option<Box<dyn HessianDerivativeProvider + 'dp>>,
    /// Always forwarded (as an `Option`) to `firth_term`.
    pub firth: Option<crate::estimate::reml::reml_outer_engine::ExactJeffreysTerm>,
    /// Forwarded to `nullspace_dim_override` only when `Some`.
    pub nullspace_dim: Option<f64>,
    // The next three are always forwarded, `None` included.
    pub barrier_config: Option<BarrierConfig>,
    pub kkt_residual: Option<ProjectedKktResidual>,
    pub active_constraints: Option<Arc<crate::model_types::ActiveLinearConstraintBlock>>,

    /// Forwarded to `ext_coords` only when non-empty.
    pub ext_coords: Vec<HyperCoord>,
    // The next three callbacks are forwarded only when `Some`.
    pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
    pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
    pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
    /// Always forwarded (as an `Option`) to `contracted_psi_second_order`.
    pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
}
360
361impl<'dp> InnerAssembly<'dp> {
362 pub fn build(self) -> InnerSolution<'dp> {
364 let mut builder = InnerSolutionBuilder::new(
365 self.log_likelihood,
366 self.penalty_quadratic,
367 self.beta,
368 self.n_observations,
369 self.hessian_op,
370 self.penalty_coords,
371 self.penalty_logdet,
372 self.dispersion,
373 );
374 builder = builder.rho_curvature_scale(self.rho_curvature_scale);
375 builder = builder.rho_prior(self.rho_prior);
376 builder = builder.hessian_logdet_correction(self.hessian_logdet_correction);
377 builder = builder.penalty_subspace_trace(self.penalty_subspace_trace);
378
379 if let Some(dp) = self.deriv_provider {
380 builder = builder.deriv_provider(dp);
381 }
382 builder = builder.firth_term(self.firth);
383 if let Some(nd) = self.nullspace_dim {
384 builder = builder.nullspace_dim_override(nd);
385 }
386 builder = builder.barrier_config(self.barrier_config);
387 builder = builder.kkt_residual(self.kkt_residual);
388 builder = builder.active_constraints(self.active_constraints);
389
390 if !self.ext_coords.is_empty() {
391 builder = builder.ext_coords(self.ext_coords);
392 }
393 if let Some(f) = self.ext_coord_pair_fn {
394 builder = builder.ext_coord_pair_fn(f);
395 }
396 if let Some(f) = self.rho_ext_pair_fn {
397 builder = builder.rho_ext_pair_fn(f);
398 }
399 if let Some(f) = self.fixed_drift_deriv {
400 builder = builder.fixed_drift_deriv(f);
401 }
402 builder = builder.contracted_psi_second_order(self.contracted_psi_second_order);
403
404 builder.build()
405 }
406
407 pub fn evaluate(
409 self,
410 rho: &[f64],
411 mode: EvalMode,
412 prior: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
413 ) -> Result<RemlLamlResult, String> {
414 let solution = self.build();
415 reml_laml_evaluate(&solution, rho, mode, prior)
416 }
417}
418
/// Evaluates an already-built [`InnerSolution`] by passing every argument straight
/// through to `reml_laml_evaluate`.
///
/// NOTE(review): the meaning of the three components of `prior` is defined by
/// `reml_laml_evaluate`, which is not visible from this file — confirm there.
///
/// # Errors
///
/// Returns whatever error string `reml_laml_evaluate` reports.
pub fn evaluate_solution(
    solution: &InnerSolution<'_>,
    rho: &[f64],
    mode: EvalMode,
    prior: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
) -> Result<RemlLamlResult, String> {
    reml_laml_evaluate(solution, rho, mode, prior)
}
432
/// Borrowed description of one penalty block, consumed by `penalty_coords_from_blocks`.
pub struct PenaltyBlockDesc<'a> {
    /// Penalty matrix for this block; its root is taken with `penalty_matrix_root`.
    pub matrix: &'a Array2<f64>,
    /// Start of the block's index range, passed to `PenaltyCoordinate::from_block_root`.
    /// NOTE(review): presumably an index into the full coefficient vector — confirm.
    pub range_start: usize,
    /// End of the block's index range, passed to `PenaltyCoordinate::from_block_root`.
    /// NOTE(review): presumably exclusive — confirm against `from_block_root`.
    pub range_end: usize,
}
443
444pub fn penalty_coords_from_blocks(
449 blocks: &[PenaltyBlockDesc],
450 total_dim: usize,
451) -> Result<Vec<PenaltyCoordinate>, String> {
452 blocks
453 .iter()
454 .map(|b| {
455 let root = penalty_matrix_root(b.matrix)?;
456 Ok(PenaltyCoordinate::from_block_root(
457 root,
458 b.range_start,
459 b.range_end,
460 total_dim,
461 ))
462 })
463 .collect()
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array2;

    /// Asserts that `got` and `expected` have the same shape and agree entry by entry
    /// under the given absolute (`epsilon`) and relative (`max_relative`) tolerances.
    pub(crate) fn assert_matrix_close(
        got: &Array2<f64>,
        expected: &Array2<f64>,
        epsilon: f64,
        max_relative: f64,
    ) {
        assert_eq!(got.dim(), expected.dim());
        for (&actual, &wanted) in got.iter().zip(expected.iter()) {
            assert_relative_eq!(
                actual,
                wanted,
                epsilon = epsilon,
                max_relative = max_relative
            );
        }
    }

    /// Builds a reproducible `n × p` fixture from a fixed sine/cosine mix; `phase`
    /// shifts the pattern so different fixtures are not identical.
    pub(crate) fn deterministic_matrix(n: usize, p: usize, phase: f64) -> Array2<f64> {
        Array2::from_shape_fn((n, p), |(i, j)| {
            let (r, c) = (i as f64, j as f64);
            let sine_part = ((r + 1.0) * (c + 3.0) + phase).sin();
            let cosine_part = ((r + 5.0) / (c + 2.0) + phase).cos();
            0.25 * sine_part + 0.75 * cosine_part
        })
    }

    /// Builds reproducible non-negative weights; every 17th entry (starting at 0) is
    /// exactly zero so the zero-weight skip path is exercised.
    pub(crate) fn deterministic_weights(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| match i % 17 {
            0 => 0.0,
            _ => 0.2 + ((i as f64 + 1.0) * 0.013).sin().abs(),
        })
    }

    /// Straightforward triple-loop `leftᵀ · diag(weights) · right`, with no skipping and
    /// no parallelism, used as the expected value.
    pub(crate) fn weighted_cross_reference(
        left: &Array2<f64>,
        right: &Array2<f64>,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        let (p, q) = (left.ncols(), right.ncols());
        let mut out = Array2::<f64>::zeros((p, q));
        for (i, &w) in weights.iter().enumerate() {
            for a in 0..p {
                let scaled = w * left[[i, a]];
                for b in 0..q {
                    out[[a, b]] += scaled * right[[i, b]];
                }
            }
        }
        out
    }

    #[test]
    pub(crate) fn row_scale_dense_into_reuses_buffer_and_matches_reference() {
        let x = deterministic_matrix(37, 11, 0.3);
        let weights = deterministic_weights(x.nrows());
        let mut out = Array2::<f64>::zeros(x.raw_dim());
        let buffer_before = out.as_ptr();
        row_scale_dense_into(&x, &weights, &mut out);
        // Same shape going in, so the allocation must have been reused.
        assert_eq!(out.as_ptr(), buffer_before);
        for ((i, j), &scaled) in out.indexed_iter() {
            assert_relative_eq!(scaled, x[[i, j]] * weights[i], epsilon = 0.0);
        }
    }

    #[test]
    pub(crate) fn weighted_cross_dense_matches_rowwise_reference_at_large_scale_block_size() {
        let left = deterministic_matrix(2048, 96, 0.1);
        let right = deterministic_matrix(2048, 64, 0.7);
        let weights = deterministic_weights(left.nrows());
        let expected = weighted_cross_reference(&left, &right, &weights);
        let got = weighted_cross_dense(&left, &right, &weights);
        assert_matrix_close(&got, &expected, 5e-10, 5e-12);
    }

    #[test]
    pub(crate) fn xt_diag_x_dense_into_matches_symmetric_reference_at_large_scale_block_size() {
        let x = deterministic_matrix(1024, 96, 1.1);
        let weights = deterministic_weights(x.nrows());
        let mut scratch = Array2::<f64>::zeros((0, 0));
        let got = xt_diag_x_dense_into(&x, &weights, &mut scratch);
        let expected = weighted_cross_reference(&x, &x, &weights);
        assert_matrix_close(&got, &expected, 3e-10, 5e-12);
        // The result must equal its own transpose entry for entry.
        for ((i, j), &entry) in got.indexed_iter() {
            assert_relative_eq!(entry, got[[j, i]], epsilon = 0.0);
        }
    }
}