1use super::*;
2
3pub(crate) fn as_implicit(op: &dyn HyperOperator) -> Option<&ImplicitHyperOperator> {
4 op.as_any().downcast_ref::<ImplicitHyperOperator>()
5}
6
7pub(crate) fn as_composite(op: &dyn HyperOperator) -> Option<&CompositeHyperOperator> {
8 op.as_any().downcast_ref::<CompositeHyperOperator>()
9}
10
11pub(crate) fn as_weighted(op: &dyn HyperOperator) -> Option<&WeightedHyperOperator> {
12 op.as_any().downcast_ref::<WeightedHyperOperator>()
13}
14
/// Log-determinant trace helpers for a drift derivative, evaluated against a
/// [`HessianOperator`].
pub(crate) trait DriftDerivTraceExt {
    /// Returns the log-determinant trace contribution of this drift, as computed by `hop`.
    fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64;

    /// Returns the cross (second-order) log-determinant trace term between `self` and
    /// `rhs`, as computed by `hop`.
    fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64;
}
20
21impl DriftDerivTraceExt for DriftDerivResult {
22 fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
23 match self {
24 Self::Dense(matrix) => hop.trace_logdet_gradient(matrix),
25 Self::Operator(operator) => hop.trace_logdet_operator(operator.as_ref()),
26 }
27 }
28
29 fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
30 match (self, rhs) {
31 (Self::Dense(left), Self::Dense(right)) => hop.trace_logdet_hessian_cross(left, right),
32 (Self::Dense(left), Self::Operator(right)) => {
33 hop.trace_logdet_hessian_cross_matrix_operator(left, right.as_ref())
34 }
35 (Self::Operator(left), Self::Dense(right)) => {
36 hop.trace_logdet_hessian_cross_matrix_operator(right, left.as_ref())
37 }
38 (Self::Operator(left), Self::Operator(right)) => {
39 hop.trace_logdet_hessian_cross_operator(left.as_ref(), right.as_ref())
40 }
41 }
42 }
43}
44
/// Hyper-operator equal to the sum of an optional dense matrix and a list of operators.
///
/// When `dense` is `None` and `operators` has exactly one entry, most methods of its
/// `HyperOperator` impl forward directly to that single operator.
#[derive(Clone)]
pub struct CompositeHyperOperator {
    /// Optional dense summand.
    pub dense: Option<Array2<f64>>,
    /// Operator summands, added together with `dense`.
    pub operators: Vec<Arc<dyn HyperOperator>>,
    /// Dimension reported by `dim()`; also the size of the zero matrix `to_dense` starts
    /// from when `dense` is `None`.
    pub dim_hint: usize,
}
51
52pub(crate) fn composite_trace_implicit_batched(
60 operators: &[Arc<dyn HyperOperator>],
61 factor: &Array2<f64>,
62 cache: Option<&ProjectedFactorCache>,
63) -> f64 {
64 let mut trace = 0.0;
65 let mut group_starts: Vec<Vec<usize>> = Vec::new();
66 let mut handled = vec![false; operators.len()];
67
68 for (i, op) in operators.iter().enumerate() {
69 if handled[i] {
70 continue;
71 }
72 let Some(impl_i) = as_implicit(op.as_ref()) else {
73 continue;
74 };
75 let mut group = vec![i];
76 handled[i] = true;
77 for j in (i + 1)..operators.len() {
78 if handled[j] {
79 continue;
80 }
81 if let Some(impl_j) = as_implicit(operators[j].as_ref())
82 && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
83 && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
84 && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
85 && impl_i.p == impl_j.p
86 {
87 group.push(j);
88 handled[j] = true;
89 }
90 }
91 group_starts.push(group);
92 }
93
94 for group in &group_starts {
95 if group.len() >= 2 {
96 let lead = as_implicit(operators[group[0]].as_ref()).unwrap();
97 let xf = match cache {
98 Some(c) => lead.cached_xf(factor, c),
99 None => Arc::new(lead.compute_xf(factor)),
100 };
101 let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
102 .iter()
103 .map(|&k| {
104 let op = as_implicit(operators[k].as_ref()).unwrap();
105 (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
106 })
107 .collect();
108 let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
109 trace += values.iter().sum::<f64>();
110 } else {
111 let op = &operators[group[0]];
112 trace += match cache {
113 Some(c) => op.trace_projected_factor_cached(factor, c),
114 None => op.trace_projected_factor(factor),
115 };
116 }
117 }
118
119 for (i, op) in operators.iter().enumerate() {
120 if handled[i] {
121 continue;
122 }
123 trace += match cache {
124 Some(c) => op.trace_projected_factor_cached(factor, c),
125 None => op.trace_projected_factor(factor),
126 };
127 }
128
129 trace
130}
131
132pub(crate) fn trace_projected_factors_batched(
137 operators: &[Arc<dyn HyperOperator>],
138 factor: &Array2<f64>,
139 cache: &ProjectedFactorCache,
140) -> Vec<f64> {
141 let mut out = vec![0.0; operators.len()];
142 let mut handled = vec![false; operators.len()];
143
144 for i in 0..operators.len() {
145 if handled[i] {
146 continue;
147 }
148 let Some(impl_i) = as_implicit(operators[i].as_ref()) else {
149 out[i] = operators[i].trace_projected_factor_cached(factor, cache);
150 handled[i] = true;
151 continue;
152 };
153
154 let mut group = vec![i];
155 handled[i] = true;
156 for j in (i + 1)..operators.len() {
157 if handled[j] {
158 continue;
159 }
160 if let Some(impl_j) = as_implicit(operators[j].as_ref())
161 && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
162 && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
163 && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
164 && impl_i.p == impl_j.p
165 {
166 group.push(j);
167 handled[j] = true;
168 }
169 }
170
171 if group.len() >= 2 {
172 let xf = impl_i.cached_xf(factor, cache);
173 let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
174 .iter()
175 .map(|&idx| {
176 let op = as_implicit(operators[idx].as_ref()).unwrap();
177 (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
178 })
179 .collect();
180 let values = impl_i.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
181 for (&idx, value) in group.iter().zip(values) {
182 out[idx] = value;
183 }
184 } else {
185 out[i] = operators[i].trace_projected_factor_cached(factor, cache);
186 }
187 }
188
189 out
190}
191
192pub(crate) fn collect_projected_trace_terms<'a>(
193 out_idx: usize,
194 weight: f64,
195 op: &'a dyn HyperOperator,
196 factor: &Array2<f64>,
197 dense_acc: &mut [f64],
198 terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
199) {
200 if weight == 0.0 {
201 return;
202 }
203 if let Some(composite) = as_composite(op) {
204 if let Some(dense) = composite.dense.as_ref() {
205 dense_acc[out_idx] += weight * dense_trace_projected_factor(dense, factor);
206 }
207 for inner in &composite.operators {
208 collect_projected_trace_terms(
209 out_idx,
210 weight,
211 inner.as_ref(),
212 factor,
213 dense_acc,
214 terms,
215 );
216 }
217 } else if let Some(weighted) = as_weighted(op) {
218 for (term_weight, inner) in &weighted.terms {
219 collect_projected_trace_terms(
220 out_idx,
221 weight * *term_weight,
222 inner.as_ref(),
223 factor,
224 dense_acc,
225 terms,
226 );
227 }
228 } else {
229 terms.push((out_idx, weight, op));
230 }
231}
232
233pub(crate) fn collect_projected_matrix_terms<'a>(
234 out_idx: usize,
235 weight: f64,
236 op: &'a dyn HyperOperator,
237 factor: &Array2<f64>,
238 dense_acc: &mut [Array2<f64>],
239 terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
240) {
241 if weight == 0.0 {
242 return;
243 }
244 if let Some(composite) = as_composite(op) {
245 if let Some(dense) = composite.dense.as_ref() {
246 dense_acc[out_idx].scaled_add(weight, &dense_projected_matrix(dense, factor));
247 }
248 for inner in &composite.operators {
249 collect_projected_matrix_terms(
250 out_idx,
251 weight,
252 inner.as_ref(),
253 factor,
254 dense_acc,
255 terms,
256 );
257 }
258 } else if let Some(weighted) = as_weighted(op) {
259 for (term_weight, inner) in &weighted.terms {
260 collect_projected_matrix_terms(
261 out_idx,
262 weight * *term_weight,
263 inner.as_ref(),
264 factor,
265 dense_acc,
266 terms,
267 );
268 }
269 } else {
270 terms.push((out_idx, weight, op));
271 }
272}
273
274pub(crate) fn trace_projected_operator_terms_batched(
275 n_out: usize,
276 terms: &[(usize, f64, &dyn HyperOperator)],
277 factor: &Array2<f64>,
278 cache: &ProjectedFactorCache,
279) -> Vec<f64> {
280 let mut out = vec![0.0_f64; n_out];
281 let mut handled = vec![false; terms.len()];
282
283 for i in 0..terms.len() {
284 if handled[i] {
285 continue;
286 }
287 let Some(impl_i) = as_implicit(terms[i].2) else {
288 continue;
289 };
290 let mut group = vec![i];
291 handled[i] = true;
292 for j in (i + 1)..terms.len() {
293 if handled[j] {
294 continue;
295 }
296 if let Some(impl_j) = as_implicit(terms[j].2)
297 && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
298 && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
299 && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
300 && impl_i.p == impl_j.p
301 {
302 group.push(j);
303 handled[j] = true;
304 }
305 }
306
307 let lead = as_implicit(terms[group[0]].2).unwrap();
308 let xf = lead.cached_xf(factor, cache);
309 let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
310 .iter()
311 .map(|&term_idx| {
312 let op = as_implicit(terms[term_idx].2).unwrap();
313 (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
314 })
315 .collect();
316 let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
317 for (&term_idx, value) in group.iter().zip(values.iter()) {
318 let (out_idx, weight, _) = terms[term_idx];
319 out[out_idx] += weight * *value;
320 }
321 }
322
323 for (i, (out_idx, weight, op)) in terms.iter().enumerate() {
324 if handled[i] {
325 continue;
326 }
327 out[*out_idx] += *weight * op.trace_projected_factor_cached(factor, cache);
328 }
329
330 out
331}
332
333pub(crate) fn projected_operator_terms_batched(
334 n_out: usize,
335 terms: &[(usize, f64, &dyn HyperOperator)],
336 factor: &Array2<f64>,
337 cache: &ProjectedFactorCache,
338) -> Vec<Array2<f64>> {
339 let rank = factor.ncols();
340 let mut out: Vec<Array2<f64>> = (0..n_out)
341 .map(|_| Array2::<f64>::zeros((rank, rank)))
342 .collect();
343 for (out_idx, weight, op) in terms.iter() {
344 let projected = op.projected_matrix_cached(factor, cache);
345 out[*out_idx].scaled_add(*weight, &projected);
346 }
347 out
348}
349
/// Projects weighted operator terms onto `factor`, accumulating them into `n_out` matrices.
///
/// Forwards unchanged to `projected_operator_terms_batched`.
pub(crate) fn project_hyper_operators_batched(
    n_out: usize,
    terms: &[(usize, f64, &dyn HyperOperator)],
    factor: &Array2<f64>,
    cache: &ProjectedFactorCache,
) -> Vec<Array2<f64>> {
    projected_operator_terms_batched(n_out, terms, factor, cache)
}
358
359pub(crate) fn trace_logdet_drifts_projected_factor_batched(
360 drifts: &[DriftDerivResult],
361 factor: &Array2<f64>,
362 cache: &ProjectedFactorCache,
363) -> Vec<f64> {
364 let mut out = vec![0.0_f64; drifts.len()];
365 let mut terms: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
366 for (idx, drift) in drifts.iter().enumerate() {
367 match drift {
368 DriftDerivResult::Dense(matrix) => {
369 out[idx] += dense_trace_projected_factor(matrix, factor);
370 }
371 DriftDerivResult::Operator(op) => {
372 collect_projected_trace_terms(idx, 1.0, op.as_ref(), factor, &mut out, &mut terms);
373 }
374 }
375 }
376 let batched = trace_projected_operator_terms_batched(drifts.len(), &terms, factor, cache);
377 for (dst, value) in out.iter_mut().zip(batched) {
378 *dst += value;
379 }
380 out
381}
382
/// Projected-factor drift traces for a `DenseSpectralOperator`.
///
/// Uses the operator's `g_factor` as the factor and its `projected_factor_cache`, so
/// `X·factor` products are shared with other calls that use the same cache.
pub(crate) fn dense_spectral_trace_logdet_drifts_batched(
    ds: &DenseSpectralOperator,
    drifts: &[DriftDerivResult],
) -> Vec<f64> {
    trace_logdet_drifts_projected_factor_batched(drifts, &ds.g_factor, &ds.projected_factor_cache)
}
389
390pub(crate) fn penalty_subspace_trace_factor(kernel: &PenaltySubspaceTrace) -> Array2<f64> {
391 let (evals, evecs) = kernel
392 .h_proj_inverse
393 .eigh(faer::Side::Lower)
394 .expect("PenaltySubspaceTrace kernel factor eigendecomposition failed");
395 let r = evals.len();
396 let mut root = evecs.clone();
413 for col in 0..r {
414 let scale = evals[col].max(0.0).sqrt();
415 for row in 0..r {
416 root[[row, col]] *= scale;
417 }
418 }
419 gam_linalg::faer_ndarray::fast_ab(&kernel.u_s, &root)
420}
421
422pub(crate) fn penalty_subspace_trace_drifts_batched(
423 kernel: &PenaltySubspaceTrace,
424 drifts: &[DriftDerivResult],
425) -> Vec<f64> {
426 let factor = penalty_subspace_trace_factor(kernel);
427 let cache = ProjectedFactorCache::default();
428 trace_logdet_drifts_projected_factor_batched(drifts, &factor, &cache)
429}
430
431pub(crate) fn penalty_subspace_reduce_drifts_batched(
432 kernel: &PenaltySubspaceTrace,
433 drifts: &[DriftDerivResult],
434) -> Vec<Array2<f64>> {
435 drifts
436 .iter()
437 .map(|drift| match drift {
438 DriftDerivResult::Dense(matrix) => kernel.reduce(matrix),
439 DriftDerivResult::Operator(op) => kernel.reduce_operator(op.as_ref()),
449 })
450 .collect()
451}
452
453pub(crate) fn dense_spectral_trace_logdet_operators_batched(
454 ds: &DenseSpectralOperator,
455 operators: &[Arc<dyn HyperOperator>],
456) -> Vec<f64> {
457 if operators.is_empty() {
458 return Vec::new();
459 }
460 if log::log_enabled!(log::Level::Info) {
461 let start = std::time::Instant::now();
462 let out =
463 trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache);
464 let implicit_count = operators.iter().filter(|op| op.is_implicit()).count();
465 dense_spectral_stage_log(
466 &format!(
467 "DenseSpectralOperator::trace_logdet_operators_batched dim={} rank={} ops={} implicit_ops={}",
468 ds.n_dim,
469 ds.g_factor.ncols(),
470 operators.len(),
471 implicit_count,
472 ),
473 start.elapsed().as_secs_f64(),
474 );
475 out
476 } else {
477 trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache)
478 }
479}
480
// Sum-of-parts operator: every product is `dense` (if any) plus each entry of `operators`.
// When there is no dense part and exactly one operator, most methods forward straight to
// that operator instead of allocating and accumulating.
impl HyperOperator for CompositeHyperOperator {
    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
        self
    }

    /// Reports `dim_hint`.
    fn dim(&self) -> usize {
        self.dim_hint
    }

    /// Allocating wrapper around `mul_vec_into`; the output has `v.len()` entries.
    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v.view(), out.view_mut());
        out
    }

    /// Allocating wrapper around `mul_vec_into` for a borrowed view.
    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(v.len());
        self.mul_vec_into(v, out.view_mut());
        out
    }

    /// Overwrites `out` with `dense·v + Σ op·v`.
    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
        // Fast path: a lone operator writes the result itself.
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_vec_into(v, out);
            return;
        }

        out.fill(0.0);
        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_into(dense, v, out.view_mut());
        }
        // Each operator accumulates its product on top of what is already in `out`.
        for op in &self.operators {
            op.scaled_add_mul_vec(v, 1.0, out.view_mut());
        }
    }

    /// Overwrites `out` with columns `start..start + out.ncols()` of the summed operator.
    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].mul_basis_columns_into(start, out);
            return;
        }

        out.fill(0.0);
        let cols = out.ncols();
        let end = start + cols;
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.slice(ndarray::s![.., start..end]);
        }
        // One reusable buffer; each operator writes its columns there before they are
        // added into `out`.
        let mut work = Array2::<f64>::zeros((out.nrows(), cols));
        for op in &self.operators {
            op.mul_basis_columns_into(start, work.view_mut());
            out += &work;
        }
    }

    /// Adds `scale · (dense·v + Σ op·v)` into `out`; a zero `scale` is a no-op.
    fn scaled_add_mul_vec(
        &self,
        v: ArrayView1<'_, f64>,
        scale: f64,
        mut out: ArrayViewMut1<'_, f64>,
    ) {
        if scale == 0.0 {
            return;
        }
        if self.dense.is_none() && self.operators.len() == 1 {
            self.operators[0].scaled_add_mul_vec(v, scale, out);
            return;
        }

        if let Some(dense) = self.dense.as_ref() {
            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
        }
        for op in &self.operators {
            op.scaled_add_mul_vec(v, scale, out.view_mut());
        }
    }

    /// Returns `dense·factor + Σ op·factor` (shape `factor.nrows() × factor.ncols()`).
    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].mul_mat(factor);
        }
        let p = factor.nrows();
        let k = factor.ncols();
        let mut out = Array2::<f64>::zeros((p, k));
        if let Some(dense) = self.dense.as_ref() {
            out += &dense.dot(factor);
        }
        for op in &self.operators {
            out += &op.mul_mat(factor);
        }
        out
    }

    /// Returns `tr(Fᵀ·A·F)` for the summed operator `A`, without a cache.
    ///
    /// Compatible implicit children are batched by `composite_trace_implicit_batched`.
    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor(factor);
        }

        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            // Σ_ij F_ij (D F)_ij = tr(Fᵀ D F).
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, None);
        trace
    }

    /// Cached variant of `trace_projected_factor`; implicit children share `X·factor`
    /// products through `cache`.
    fn trace_projected_factor_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> f64 {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].trace_projected_factor_cached(factor, cache);
        }

        let mut trace = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            // Σ_ij F_ij (D F)_ij = tr(Fᵀ D F).
            let dense_factor = dense.dot(factor);
            trace += factor
                .iter()
                .zip(dense_factor.iter())
                .map(|(&f, &bf)| f * bf)
                .sum::<f64>();
        }
        trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
        trace
    }

    /// Returns the `rank × rank` matrix `Fᵀ·A·F` (`rank = factor.ncols()`).
    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].projected_matrix(factor);
        }

        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = self.dense.as_ref() {
            let mf = gam_linalg::faer_ndarray::fast_ab(dense, factor);
            projected += &gam_linalg::faer_ndarray::fast_atb(factor, &mf);
        }
        for op in &self.operators {
            projected += &op.projected_matrix(factor);
        }
        projected
    }

    /// Cached variant of `projected_matrix`; children use `projected_matrix_cached`.
    fn projected_matrix_cached(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> Array2<f64> {
        if self.dense.is_none() && self.operators.len() == 1 {
            return self.operators[0].projected_matrix_cached(factor, cache);
        }

        let rank = factor.ncols();
        let mut projected = Array2::<f64>::zeros((rank, rank));
        if let Some(dense) = self.dense.as_ref() {
            let mf = gam_linalg::faer_ndarray::fast_ab(dense, factor);
            projected += &gam_linalg::faer_ndarray::fast_atb(factor, &mf);
        }
        for op in &self.operators {
            projected += &op.projected_matrix_cached(factor, cache);
        }
        projected
    }

    /// Returns `vᵀ·A·u` as the sum of the dense and per-operator bilinear forms.
    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            total += dense_bilinear(dense, v.view(), u.view());
        }
        for op in &self.operators {
            total += op.bilinear(v, u);
        }
        total
    }

    /// View-based variant of `bilinear`.
    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
        let mut total = 0.0;
        if let Some(dense) = self.dense.as_ref() {
            total += dense_bilinear(dense, v, u);
        }
        for op in &self.operators {
            total += op.bilinear_view(v, u);
        }
        total
    }

    /// Materialises the summed operator; starts from a `dim_hint × dim_hint` zero matrix
    /// when there is no dense part.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = self
            .dense
            .clone()
            .unwrap_or_else(|| Array2::<f64>::zeros((self.dim_hint, self.dim_hint)));
        for op in &self.operators {
            out += &op.to_dense();
        }
        out
    }

    /// True when any child operator is implicit.
    fn is_implicit(&self) -> bool {
        self.operators.iter().any(|op| op.is_implicit())
    }
}
693
mod implicit_matvec_scratch {
    //! Per-thread reusable work buffers for implicit hyper-operator matvecs.
    use std::cell::RefCell;

    /// Growable work vectors kept alive between calls on the same thread.
    ///
    /// Callers clear and resize each buffer to the length they need, so only the capacity
    /// persists meaningfully between uses.
    pub(super) struct Scratch {
        /// Buffer for a design-times-vector product.
        pub x_v: Vec<f64>,
        /// Observation-length work buffer.
        pub n_work: Vec<f64>,
        /// Coefficient-length work buffer.
        pub p_work: Vec<f64>,
    }

    impl Scratch {
        /// Empty buffers; `const` so it can seed the `const` thread-local initializer.
        pub(crate) const fn new() -> Self {
            Self {
                x_v: Vec::new(),
                n_work: Vec::new(),
                p_work: Vec::new(),
            }
        }
    }

    thread_local! {
        static BUFFERS: RefCell<Scratch> = const { RefCell::new(Scratch::new()) };
    }

    /// Runs `f` with exclusive access to this thread's scratch buffers and returns its result.
    ///
    /// # Panics
    /// Panics if called re-entrantly from within `f` on the same thread, because the
    /// buffers are already mutably borrowed.
    pub(super) fn with<R>(f: impl FnOnce(&mut Scratch) -> R) -> R {
        BUFFERS.with(|buffers| {
            let mut guard = buffers.borrow_mut();
            f(&mut guard)
        })
    }
}
735
/// Matrix-free hyper-operator assembled from an implicit design derivative.
///
/// Its matvec (`matvec_with_shared_xz_into`) computes
/// `Dᵀ(w∘Xv) + Xᵀ(w∘Dv) + S_psi·v [+ Xᵀ(c∘Xv)]`. Here `D` is `implicit_deriv` applied
/// along `axis`, `X` is `x_design`, `w` is `w_diag`, and the bracketed term is present
/// only when `c_x_psi_beta` is set.
pub struct ImplicitHyperOperator {
    /// Shared implicit derivative; `forward_mul`/`transpose_mul` apply `D`/`Dᵀ`.
    pub implicit_deriv: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
    /// Axis index passed to every `implicit_deriv` call.
    pub axis: usize,
    /// Design matrix `X`. Its `Arc` address also keys cached `X·factor` products.
    pub(crate) x_design: std::sync::Arc<DesignMatrix>,
    /// Signed observation weights `w`; their length is the observation count.
    pub(crate) w_diag: gam_linalg::matrix::SignedWeightsArc,
    /// Dense penalty-derivative summand `S_psi`.
    pub s_psi: Array2<f64>,
    /// Coefficient dimension, reported by `dim()`.
    pub(crate) p: usize,
    /// Optional per-observation correction weights `c` for the `Xᵀ(c∘Xv)` term.
    pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
}
761
762impl HyperOperator for ImplicitHyperOperator {
763 fn dim(&self) -> usize {
764 self.p
765 }
766
767 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
768 let mut out = Array1::<f64>::zeros(self.p);
774 self.mul_vec_into(v.view(), out.view_mut());
775 out
776 }
777
778 fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
779 let mut out = Array1::<f64>::zeros(self.p);
780 self.mul_vec_into(v, out.view_mut());
781 out
782 }
783
784 fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
785 assert_eq!(v.len(), self.p);
786 let n_obs = self.w_diag.len();
787 implicit_matvec_scratch::with(|s| {
791 s.x_v.clear();
792 s.x_v.resize(n_obs, 0.0);
793 s.n_work.clear();
794 s.n_work.resize(n_obs, 0.0);
795 s.p_work.clear();
796 s.p_work.resize(self.p, 0.0);
797 let mut x_v_view = ndarray::ArrayViewMut1::from(s.x_v.as_mut_slice());
798 let n_work_view = ndarray::ArrayViewMut1::from(s.n_work.as_mut_slice());
799 let p_work_view = ndarray::ArrayViewMut1::from(s.p_work.as_mut_slice());
800 design_matrix_apply_view_into(&self.x_design, v, x_v_view.view_mut());
801 self.matvec_with_shared_xz_into(x_v_view.view(), v, out, n_work_view, p_work_view);
802 });
803 }
804
805 fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
806 let cols = out.ncols();
807 assert!(start + cols <= self.p);
808
809 let n_obs = self.w_diag.len();
810 let mut basis = Array1::<f64>::zeros(self.p);
811 let mut x_col = Array1::<f64>::zeros(n_obs);
812 let mut dx_col = Array1::<f64>::zeros(n_obs);
813 let mut weighted = Array1::<f64>::zeros(n_obs);
814 let mut term = Array1::<f64>::zeros(self.p);
815
816 for local_col in 0..cols {
817 let global_col = start + local_col;
818 let mut out_col = out.column_mut(local_col);
819 out_col.assign(&self.s_psi.column(global_col));
820
821 design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
822 Zip::from(weighted.view_mut())
823 .and(self.w_diag.view())
824 .and(x_col.view())
825 .par_for_each(|dst, &w, &x| *dst = w * x);
826 term.assign(
827 &self
828 .implicit_deriv
829 .transpose_mul(self.axis, &weighted.view())
830 .expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
831 );
832 out_col += &term;
833
834 basis[global_col] = 1.0;
835 dx_col.assign(
836 &self
837 .implicit_deriv
838 .forward_mul(self.axis, &basis.view())
839 .expect("radial scalar evaluation failed during implicit hyper forward_mul"),
840 );
841 basis[global_col] = 0.0;
842
843 Zip::from(weighted.view_mut())
844 .and(self.w_diag.view())
845 .and(dx_col.view())
846 .par_for_each(|dst, &w, &dx| *dst = w * dx);
847 design_matrix_transpose_apply_view_into(
848 &self.x_design,
849 weighted.view(),
850 term.view_mut(),
851 );
852 out_col += &term;
853
854 self.accumulate_c_correction_xt_into(
856 x_col.view(),
857 weighted.view_mut(),
858 term.view_mut(),
859 out_col,
860 );
861 }
862 }
863
864 fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
865 self.bilinear_view(v.view(), u.view())
866 }
867
868 fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
869 assert_eq!(v.len(), self.p);
870 assert_eq!(u.len(), self.p);
871
872 let x_v = design_matrix_apply_view(&self.x_design, v);
873 let x_u = design_matrix_apply_view(&self.x_design, u);
874 let dx_v = self
875 .implicit_deriv
876 .forward_mul(self.axis, &v)
877 .expect("radial scalar evaluation failed during implicit hyper forward_mul");
878 let dx_u = self
879 .implicit_deriv
880 .forward_mul(self.axis, &u)
881 .expect("radial scalar evaluation failed during implicit hyper forward_mul");
882
883 let w = &*self.w_diag;
884 let mut design = 0.0;
885 for i in 0..w.len() {
886 design += dx_v[i] * w[i] * x_u[i];
887 design += dx_u[i] * w[i] * x_v[i];
888 }
889
890 design += self.c_correction_bilinear(&x_v, &x_u);
891
892 let penalty = dense_bilinear(&self.s_psi, v, u);
893
894 design + penalty
895 }
896
897 fn is_implicit(&self) -> bool {
898 true
899 }
900
901 fn as_any(&self) -> &(dyn std::any::Any + 'static) {
902 self
903 }
904
905 fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
941 assert_eq!(factor.nrows(), self.p);
942 let n_obs = self.w_diag.len();
943 let rank = factor.ncols();
944 if rank == 0 || n_obs == 0 {
945 return 0.0;
946 }
947 let xf = self.compute_xf(factor);
948 self.trace_projected_factor_with_xf(factor, xf.view())
949 }
950
951 fn trace_projected_factor_cached(
962 &self,
963 factor: &Array2<f64>,
964 cache: &ProjectedFactorCache,
965 ) -> f64 {
966 assert_eq!(factor.nrows(), self.p);
967 let n_obs = self.w_diag.len();
968 let rank = factor.ncols();
969 if rank == 0 || n_obs == 0 {
970 return 0.0;
971 }
972 let xf = self.cached_xf(factor, cache);
973 self.trace_projected_factor_with_xf(factor, xf.view())
974 }
975}
976
/// Number of rows to process per chunk so that a chunk of `cols` `f64` columns stays near
/// 8 MiB.
///
/// The result is never below 512 rows unless `n_rows` itself is smaller. It is capped at
/// `n_rows`, so it is `0` when `n_rows == 0`. A `cols` of zero is treated as one column.
pub(crate) fn byte_balanced_row_chunk(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = std::mem::size_of::<f64>() * cols.max(1);
    let budget_rows = TARGET_BYTES / row_bytes;
    let floored = if budget_rows < MIN_CHUNK_ROWS {
        MIN_CHUNK_ROWS
    } else {
        budget_rows
    };
    std::cmp::min(floored, n_rows)
}
989
// Numeric kernels shared by the `HyperOperator` impl and the batched trace routines.
// Notation: `X` = x_design, `D` = implicit_deriv along an axis, `w` = w_diag,
// `c` = c_x_psi_beta, `F` = factor.
impl ImplicitHyperOperator {
    /// Computes `X·F` (`n_obs × rank`) by streaming row chunks of the design matrix.
    ///
    /// A failure to materialise a row chunk is routed to `reml_contract_panic`.
    pub(crate) fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
        let n_obs = self.w_diag.len();
        let rank = factor.ncols();
        let mut xf = Array2::<f64>::zeros((n_obs, rank));
        // Chunk size is budgeted for a row of the design (p) plus a row of the result (rank).
        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let rows = self
                .x_design
                .try_row_chunk(start..end)
                .unwrap_or_else(|err| {
                    reml_contract_panic(format!(
                        "ImplicitHyperOperator::compute_xf row chunk failed: {err}"
                    ))
                });
            let block = gam_linalg::faer_ndarray::fast_ab(&rows, factor);
            xf.slice_mut(ndarray::s![start..end, ..]).assign(&block);
            start = end;
        }
        xf
    }

    /// Returns `X·F`, reusing an entry in `cache` when one exists for this design and factor.
    ///
    /// The key combines the address of the `x_design` `Arc` with the factor view.
    /// NOTE(review): an address-based key could in principle match a different design that
    /// reuses the same allocation after the original `Arc` is dropped. Confirm that caches
    /// never outlive the designs they were filled from.
    pub(crate) fn cached_xf(
        &self,
        factor: &Array2<f64>,
        cache: &ProjectedFactorCache,
    ) -> Arc<Array2<f64>> {
        let design_id = Arc::as_ptr(&self.x_design) as usize;
        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
        cache.get_or_insert_with(key, || self.compute_xf(factor))
    }

    /// Returns `tr(Fᵀ·A·F)` given a precomputed `xf = X·F`.
    ///
    /// Computes `2·Σ_i w_i ⟨(D U)_i, xf_i⟩ + Σ_i c_i ‖xf_i‖² + Σ F∘(S_psi F)`, where
    /// `U = implicit_deriv.unproject_matrix(F)` and the rows of `D` come from
    /// `row_chunk_first_raw`. The correction sum is skipped when `c_x_psi_beta` is `None`.
    ///
    /// # Panics
    /// Panics if `xf` is not `n_obs × rank`, or if a derivative row chunk fails to evaluate.
    pub(crate) fn trace_projected_factor_with_xf(
        &self,
        factor: &Array2<f64>,
        xf: ArrayView2<'_, f64>,
    ) -> f64 {
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        assert_eq!(xf.dim(), (n_obs, rank));

        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());

        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);

        let w = self.w_diag.as_ref();
        let c_opt = self.c_x_psi_beta.as_ref().map(|arc| arc.as_ref());
        let mut design_total = 0.0_f64;
        let mut correction_total = 0.0_f64;
        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;

            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);

            // (D·U) restricted to rows start..end.
            let kd_chunk = self
                .implicit_deriv
                .row_chunk_first_raw(self.axis, start..end)
                .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
            let dxf_chunk = gam_linalg::faer_ndarray::fast_ab(&kd_chunk, &u_knot);

            for i_local in 0..chunk_n {
                let i = start + i_local;
                let w_i = w[i];
                let dxf_row = dxf_chunk.row(i_local);
                let xf_row = xf_chunk.row(i_local);
                for k in 0..rank {
                    design_total += dxf_row[k] * w_i * xf_row[k];
                }
                if let Some(c) = c_opt {
                    let c_i = c[i];
                    for k in 0..rank {
                        let v = xf_row[k];
                        correction_total += c_i * v * v;
                    }
                }
            }
            start = end;
        }

        // tr(Fᵀ S_psi F) as an elementwise sum.
        let s_f = self.s_psi.dot(factor);
        let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();

        // The two cross design terms are equal under the trace, hence the factor 2.
        2.0 * design_total + correction_total + penalty
    }

    /// Evaluates `trace_projected_factor_with_xf` for several axes in one sweep over the rows.
    ///
    /// Each `axes` entry supplies its own `(axis, S_psi, c)`. `self` provides the shared
    /// `implicit_deriv`, `w_diag`, `p` and the `U = unproject_matrix(F)` used for every
    /// entry. Results are returned in the order of `axes`.
    ///
    /// # Panics
    /// Panics if `xf` is not `n_obs × rank`, or if a derivative row chunk fails to evaluate.
    pub(crate) fn trace_projected_factor_all_axes_with_xf(
        &self,
        factor: &Array2<f64>,
        xf: ArrayView2<'_, f64>,
        axes: &[(usize, &Array2<f64>, Option<&Array1<f64>>)],
    ) -> Vec<f64> {
        let rank = factor.ncols();
        let n_obs = self.w_diag.len();
        assert_eq!(xf.dim(), (n_obs, rank));

        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());

        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs.max(1));

        let w = self.w_diag.as_ref();
        let mut design_totals = vec![0.0_f64; axes.len()];
        let mut correction_totals = vec![0.0_f64; axes.len()];

        let mut start = 0usize;
        while start < n_obs {
            let end = (start + chunk_rows).min(n_obs);
            let chunk_n = end - start;
            // The XF rows for this chunk are shared by every axis.
            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);

            for (axis_idx, (axis, _s_psi, c_opt_axis)) in axes.iter().enumerate() {
                let kd_chunk = self
                    .implicit_deriv
                    .row_chunk_first_raw(*axis, start..end)
                    .expect(
                        "radial scalar evaluation failed during \
                         trace_projected_factor_all_axes_with_xf",
                    );
                let dxf_chunk = gam_linalg::faer_ndarray::fast_ab(&kd_chunk, &u_knot);

                for i_local in 0..chunk_n {
                    let i = start + i_local;
                    let w_i = w[i];
                    let dxf_row = dxf_chunk.row(i_local);
                    let xf_row = xf_chunk.row(i_local);
                    for k in 0..rank {
                        design_totals[axis_idx] += dxf_row[k] * w_i * xf_row[k];
                    }
                    if let Some(c) = c_opt_axis {
                        let c_i = c[i];
                        for k in 0..rank {
                            let v = xf_row[k];
                            correction_totals[axis_idx] += c_i * v * v;
                        }
                    }
                }
            }
            start = end;
        }

        axes.iter()
            .enumerate()
            .map(|(idx, (_axis, s_psi, _c_opt_axis))| {
                // Per-axis penalty tr(Fᵀ S_psi F).
                let s_f = s_psi.dot(factor);
                let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
                2.0 * design_totals[idx] + correction_totals[idx] + penalty
            })
            .collect()
    }

    /// Adds `Xᵀ(c ∘ x_col)` into `out_col`; does nothing when `c_x_psi_beta` is `None`.
    ///
    /// `n_work` and `p_work` are overwritten as scratch space.
    ///
    /// # Panics
    /// Panics if `x_col` or `n_work` does not match the length of `c`, or if
    /// `p_work.len() != p`.
    pub(crate) fn accumulate_c_correction_xt_into(
        &self,
        x_col: ArrayView1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
        mut out_col: ArrayViewMut1<'_, f64>,
    ) {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return;
        };
        let c = c_x_psi_beta.as_ref();
        assert_eq!(x_col.len(), c.len());
        assert_eq!(n_work.len(), c.len());
        assert_eq!(p_work.len(), self.p);

        for i in 0..c.len() {
            n_work[i] = c[i] * x_col[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out_col += &p_work;
    }

    /// Returns `Σ_i x_v[i]·c[i]·x_u[i]`, or 0.0 when `c_x_psi_beta` is `None`.
    ///
    /// NOTE(review): lengths are not checked; `zip` silently stops at the shortest input.
    pub(crate) fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
            return 0.0;
        };
        x_v.iter()
            .zip(x_u.iter())
            .zip(c_x_psi_beta.iter())
            .map(|((&xv, &xu), &c)| xv * c * xu)
            .sum()
    }

    /// Bilinear form `zᵀ·A·u` reusing precomputed design products.
    ///
    /// Computes `Σ_i w_i((Dz)_i·y_i + (Du)_i·x_i) [+ Σ_i y_i c_i x_i] + zᵀ S_psi u`.
    /// Presumably `x_vec = X·z` and `y_vec = X·u`, in which case this matches
    /// `bilinear_view(z, u)` — NOTE(review): confirm against callers.
    pub fn bilinear_with_shared_x(
        &self,
        x_vec: &Array1<f64>,
        y_vec: &Array1<f64>,
        z: &Array1<f64>,
        u: &Array1<f64>,
    ) -> f64 {
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        let dx_u = self
            .implicit_deriv
            .forward_mul(self.axis, &u.view())
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");

        let mut design = 0.0f64;
        let w = &*self.w_diag;
        for i in 0..x_vec.len() {
            let wi = w[i];
            design += dx_z[i] * wi * y_vec[i];
            design += dx_u[i] * wi * x_vec[i];
        }

        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..x_vec.len() {
                design += y_vec[i] * c[i] * x_vec[i];
            }
        }

        let penalty = dense_bilinear(&self.s_psi, z.view(), u.view());

        design + penalty
    }

    /// Writes `Dᵀ(w∘x_vec) + Xᵀ(w∘Dz) + S_psi·z [+ Xᵀ(c∘x_vec)]` into `out`.
    ///
    /// `x_vec` is expected to be `X·z`; `mul_vec_into` passes exactly that. `n_work`
    /// (length `n_obs`) and `p_work` (length `p`) are overwritten as scratch space.
    ///
    /// # Panics
    /// Panics on length mismatches of `z`, `out`, `n_work` or `p_work`, or if a derivative
    /// product fails to evaluate.
    pub fn matvec_with_shared_xz_into(
        &self,
        x_vec: ArrayView1<'_, f64>,
        z: ArrayView1<'_, f64>,
        mut out: ArrayViewMut1<'_, f64>,
        mut n_work: ArrayViewMut1<'_, f64>,
        mut p_work: ArrayViewMut1<'_, f64>,
    ) {
        assert_eq!(z.len(), self.p);
        assert_eq!(out.len(), self.p);
        assert_eq!(n_work.len(), self.w_diag.len());
        assert_eq!(p_work.len(), self.p);

        // Dᵀ(w ∘ x_vec); this term initialises `out`.
        let w = &*self.w_diag;
        for i in 0..w.len() {
            n_work[i] = w[i] * x_vec[i];
        }
        let term1 = self
            .implicit_deriv
            .transpose_mul(self.axis, &n_work.view())
            .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
        out.assign(&term1);

        // Xᵀ(w ∘ D z)
        let dx_z = self
            .implicit_deriv
            .forward_mul(self.axis, &z)
            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
        for i in 0..w.len() {
            n_work[i] = w[i] * dx_z[i];
        }
        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
        out += &p_work;

        // S_psi · z
        dense_matvec_into(&self.s_psi, z, p_work.view_mut());
        out += &p_work;

        // Optional Xᵀ(c ∘ x_vec)
        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
            let c = c_x_psi_beta.as_ref();
            for i in 0..w.len() {
                n_work[i] = c[i] * x_vec[i];
            }
            design_matrix_transpose_apply_view_into(
                &self.x_design,
                n_work.view(),
                p_work.view_mut(),
            );
            out += &p_work;
        }
    }
}
1339
/// Matrix-free hyper-operator built from a directional design derivative `x_tau`.
///
/// Only `mul_vec` is implemented explicitly in its `HyperOperator` impl.
pub struct SparseDirectionalHyperOperator {
    /// Design derivative `T`; applied via `forward_mul_original` / `transpose_mul_original`.
    pub(crate) x_tau: super::super::HyperDesignDerivative,
    /// Design matrix `X`.
    pub(crate) x_design: DesignMatrix,
    /// Observation weights `w`.
    pub(crate) w_diag: gam_linalg::matrix::SignedWeightsArc,
    /// Dense summand `S_tau`.
    pub(crate) s_tau: Array2<f64>,
    /// Optional per-observation weights `c` for the `Xᵀ(c∘Xv)` term.
    pub(crate) c_x_tau_beta: Option<Array1<f64>>,
    /// Optional dense matrix whose product with `v` is subtracted.
    pub(crate) firth_hphi_tau_partial: Option<Array2<f64>>,
    /// Coefficient dimension, reported by `dim()`.
    pub(crate) p: usize,
}
1365
1366impl HyperOperator for SparseDirectionalHyperOperator {
1367 fn dim(&self) -> usize {
1368 self.p
1369 }
1370
1371 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
1372 assert_eq!(v.len(), self.p);
1373
1374 let x_v = self.x_design.matrixvectormultiply(v);
1376
1377 let w_x_v = &*self.w_diag * &x_v;
1379 let term1 = self
1380 .x_tau
1381 .transpose_mul_original(&w_x_v)
1382 .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");
1383
1384 let x_tau_v = self
1386 .x_tau
1387 .forward_mul_original(v)
1388 .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
1389 let w_x_tau_v = &*self.w_diag * &x_tau_v;
1390 let term2 = self.x_design.transpose_vector_multiply(&w_x_tau_v);
1391
1392 let term3 = self.s_tau.dot(v);
1394
1395 let mut out = term1 + term2 + term3;
1396
1397 if let Some(c_x_tau_beta) = self.c_x_tau_beta.as_ref() {
1399 let weighted = c_x_tau_beta * &x_v;
1400 out += &self.x_design.transpose_vector_multiply(&weighted);
1401 }
1402
1403 if let Some(hphi_tau_partial) = self.firth_hphi_tau_partial.as_ref() {
1405 out -= &hphi_tau_partial.dot(v);
1406 }
1407
1408 out
1409 }
1410
1411 fn is_implicit(&self) -> bool {
1412 false
1413 }
1414 fn as_any(&self) -> &(dyn std::any::Any + 'static) {
1415 self
1416 }
1417}
1418
/// Hyper-operator `v ↦ Xᵀ(neg_c_xv ∘ Xv)`, applied matrix-free in `mul_vec`.
pub struct GlmCurvatureCorrectionOperator {
    /// Design matrix `X`.
    pub(crate) x_design: DesignMatrix,
    /// Per-observation weights applied between `X` and `Xᵀ`. The name suggests an
    /// already-negated curvature term — NOTE(review): confirm the sign with the constructor.
    pub(crate) neg_c_xv: Array1<f64>,
    /// Coefficient dimension, reported by `dim()`.
    pub(crate) p: usize,
}
1449
1450impl HyperOperator for GlmCurvatureCorrectionOperator {
1451 fn dim(&self) -> usize {
1452 self.p
1453 }
1454
1455 fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
1456 assert_eq!(v.len(), self.p);
1457 let x_v = self.x_design.matrixvectormultiply(v);
1458 let weighted = &self.neg_c_xv * &x_v;
1459 self.x_design.transpose_vector_multiply(&weighted)
1460 }
1461
1462 fn as_any(&self) -> &(dyn std::any::Any + 'static) {
1463 self
1464 }
1465
1466 fn is_implicit(&self) -> bool {
1467 false
1468 }
1469}
1470
1471