Skip to main content

gam_models/
spatial_psi_bridge.rs

1//! Spatial-ψ derivative bridge.
2//!
3//! Self-contained translation layer between the smooth-side
4//! [`TermCollectionSpec`]/[`TermCollectionDesign`] world and the lower-level
5//! generic family engine ([`crate::custom_family`]). It takes the
6//! resolved spatial length-scale terms and produces the per-axis
7//! [`CustomFamilyBlockPsiDerivative`] blocks the engine consumes.
8//!
9//! Keeping this here (a *higher* layer than `custom_family`) lets the engine
10//! stay ignorant of `gam_terms::smooth`: family modules call into this
11//! bridge instead of the engine reaching up into smooth.
12
13use crate::custom_family::{
14    CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator,
15    EmbeddedImplicitPsiDerivativeOperator, build_embedded_dense_psi_operator,
16};
17use gam_linalg::matrix::{EmbeddedColumnBlock, EmbeddedSquareBlock};
18use gam_terms::smooth::{TermCollectionDesign, TermCollectionSpec};
19use crate::fit_orchestration::drivers::{
20    spatial_length_scale_term_indices, try_build_spatial_log_kappa_derivativeinfo_list,
21};
22use ndarray::Array2;
23use std::collections::HashMap;
24use std::ops::Range;
25use std::sync::Arc;
26
27pub(crate) fn wrap_spatial_implicit_psi_operator(
28    op: Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
29    global_range: Range<usize>,
30    total_p: usize,
31) -> Arc<dyn CustomFamilyPsiDerivativeOperator> {
32    Arc::new(
33        EmbeddedImplicitPsiDerivativeOperator::new(op, global_range, total_p)
34            .expect("spatial implicit psi operator should embed into full coefficient space"),
35    )
36}
37
38/// Per-block transform applied by the shared spatial-ψ derivative engine.
39///
40/// The engine (see [`build_block_spatial_psi_derivatives_with_transform`]) owns
41/// the policy of *which* spatial length-scale terms become ψ-derivative blocks,
42/// how their embedded design/penalty matrices and implicit operators are
43/// assembled, and how anisotropic cross-axis rows are wired. Family modules that
44/// need a coordinate change applied uniformly to every assembled block — e.g. a
45/// time-varying survival covariate that tensorizes each spatial design row
46/// against a time basis — invert the dependency by *providing a transform* here
47/// instead of re-implementing the assembly loop.
48///
49/// All three hooks default to the identity, so the canonical (untransformed)
50/// path is just [`build_block_spatial_psi_derivatives`].
51pub(crate) trait SpatialPsiBlockTransform {
52    /// Transform an assembled implicit ψ-derivative operator (already embedded
53    /// into the full coefficient space). The default returns it unchanged.
54    fn transform_operator(
55        &self,
56        op: Arc<dyn CustomFamilyPsiDerivativeOperator>,
57    ) -> Result<Arc<dyn CustomFamilyPsiDerivativeOperator>, String> {
58        Ok(op)
59    }
60
61    /// Transform a materialized (already embedded) design block. Default: identity.
62    fn transform_design(&self, design: Array2<f64>) -> Array2<f64> {
63        design
64    }
65
66    /// Transform a materialized (already embedded) penalty block. Default: identity.
67    fn transform_penalty(&self, penalty: Array2<f64>) -> Array2<f64> {
68        penalty
69    }
70}
71
72/// The canonical no-op transform: blocks are emitted exactly as assembled.
73pub(crate) struct IdentitySpatialPsiBlockTransform;
74
75impl SpatialPsiBlockTransform for IdentitySpatialPsiBlockTransform {}
76
77pub(crate) fn build_block_spatial_psi_derivatives(
78    data: ndarray::ArrayView2<'_, f64>,
79    resolvedspec: &TermCollectionSpec,
80    design: &TermCollectionDesign,
81) -> Result<Option<Vec<CustomFamilyBlockPsiDerivative>>, String> {
82    build_block_spatial_psi_derivatives_with_transform(
83        data,
84        resolvedspec,
85        design,
86        &IdentitySpatialPsiBlockTransform,
87    )
88}
89
90/// Shared exact-derivative / spatial-ψ engine.
91///
92/// Builds the per-axis [`CustomFamilyBlockPsiDerivative`] blocks for every
93/// spatial length-scale term, threading every materialized design/penalty matrix
94/// and every assembled implicit operator through `transform`. Family modules
95/// consume this engine and supply a [`SpatialPsiBlockTransform`] rather than
96/// duplicating the block-assembly, cross-axis, and operator-embedding logic.
97pub(crate) fn build_block_spatial_psi_derivatives_with_transform(
98    data: ndarray::ArrayView2<'_, f64>,
99    resolvedspec: &TermCollectionSpec,
100    design: &TermCollectionDesign,
101    transform: &dyn SpatialPsiBlockTransform,
102) -> Result<Option<Vec<CustomFamilyBlockPsiDerivative>>, String> {
103    let spatial_terms = spatial_length_scale_term_indices(resolvedspec);
104    let Some(info_list) =
105        try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, &spatial_terms)
106            .map_err(|e| e.to_string())?
107    else {
108        return Ok(None);
109    };
110    let psi_dim = info_list.len();
111    let axis_lookup: HashMap<(usize, usize), usize> = info_list
112        .iter()
113        .enumerate()
114        .filter_map(|(idx, info)| {
115            info.aniso_group_id
116                .map(|gid| ((gid, info.implicit_axis), idx))
117        })
118        .collect();
119    let collected: Result<Vec<CustomFamilyBlockPsiDerivative>, String> = info_list
120        .into_iter()
121        .enumerate()
122        .map(|(psi_idx, info)| {
123            let implicit_operator = info.implicit_operator.as_ref().map(|op| {
124                wrap_spatial_implicit_psi_operator(
125                    Arc::clone(op),
126                    info.global_range.clone(),
127                    info.total_p,
128                )
129            });
130            let dense_operator = if implicit_operator.is_none() && !info.x_psi_local.is_empty() {
131                Some(build_embedded_dense_psi_operator(
132                    &info.x_psi_local,
133                    &info.x_psi_psi_local,
134                    info.aniso_cross_designs.as_ref(),
135                    info.global_range.clone(),
136                    info.total_p,
137                    info.implicit_axis,
138                )?)
139            } else {
140                None
141            };
142            let design_operator = implicit_operator
143                .or(dense_operator)
144                .map(|op| transform.transform_operator(op))
145                .transpose()?;
146            let materialize_dense_design =
147                !info.x_psi_local.is_empty() && design_operator.is_none();
148            let embed_design = |local: &Array2<f64>| -> Array2<f64> {
149                let embedded = if local.ncols() == 0 || local.nrows() == 0 {
150                    Array2::<f64>::zeros((local.nrows(), info.total_p))
151                } else {
152                    EmbeddedColumnBlock::new(local, info.global_range.clone(), info.total_p)
153                        .materialize()
154                };
155                transform.transform_design(embedded)
156            };
157            let x_full = if materialize_dense_design {
158                embed_design(&info.x_psi_local)
159            } else {
160                Array2::<f64>::zeros((0, 0))
161            };
162            let penalty_indices = info.penalty_indices.clone();
163            let embed_penalty = |local: &Array2<f64>| -> Array2<f64> {
164                let embedded = if local.nrows() == 0 || local.ncols() == 0 {
165                    Array2::<f64>::zeros((info.total_p, info.total_p))
166                } else {
167                    EmbeddedSquareBlock::new(local, info.global_range.clone(), info.total_p)
168                        .materialize()
169                };
170                transform.transform_penalty(embedded)
171            };
172            let s_components: Vec<(usize, Array2<f64>)> = info
173                .penalty_indices
174                .into_iter()
175                .zip(
176                    info.s_psi_components_local
177                        .into_iter()
178                        .map(|local| embed_penalty(&local)),
179                )
180                .collect();
181            // Build x_psi_psi rows with cross-derivative designs
182            let x_psi_psi_rows = if materialize_dense_design {
183                let mut rows =
184                    vec![Array2::<f64>::zeros((x_full.nrows(), x_full.ncols())); psi_dim];
185                rows[psi_idx] = embed_design(&info.x_psi_psi_local);
186                if let (Some(gid), Some(cross_designs)) =
187                    (info.aniso_group_id, info.aniso_cross_designs.as_ref())
188                {
189                    for (axis_j, local) in cross_designs {
190                        if let Some(&global_j) = axis_lookup.get(&(gid, *axis_j)) {
191                            rows[global_j] = embed_design(local);
192                        }
193                    }
194                }
195                Some(rows)
196            } else {
197                None
198            };
199            // Build s_psi_psi_components with cross-penalty terms
200            let mut s_psi_psi_comp_rows = vec![Vec::<(usize, Array2<f64>)>::new(); psi_dim];
201            s_psi_psi_comp_rows[psi_idx] = penalty_indices
202                .iter()
203                .copied()
204                .zip(info.s_psi_psi_components_local.iter().map(&embed_penalty))
205                .collect();
206            if let (Some(gid), Some(cross_penalty_provider)) = (
207                info.aniso_group_id,
208                info.aniso_cross_penalty_provider.as_ref(),
209            ) {
210                for ((group_id, axis_j), global_j) in &axis_lookup {
211                    if *group_id != gid || *axis_j == info.implicit_axis {
212                        continue;
213                    }
214                    let local_components =
215                        cross_penalty_provider(*axis_j).map_err(|err| err.to_string())?;
216                    if local_components.is_empty() {
217                        continue;
218                    }
219                    s_psi_psi_comp_rows[*global_j] = penalty_indices
220                        .iter()
221                        .copied()
222                        .zip(local_components.iter().map(embed_penalty))
223                        .collect();
224                }
225            }
226            Ok(CustomFamilyBlockPsiDerivative {
227                penalty_index: Some(info.penalty_index),
228                x_psi: x_full,
229                s_psi: Array2::<f64>::zeros((0, 0)),
230                s_psi_components: Some(s_components),
231                s_psi_penalty_components: None,
232                x_psi_psi: x_psi_psi_rows,
233                s_psi_psi: None,
234                s_psi_psi_components: Some(s_psi_psi_comp_rows),
235                s_psi_psi_penalty_components: None,
236                implicit_operator: design_operator,
237                implicit_axis: info.implicit_axis,
238                implicit_group_id: info.aniso_group_id,
239            })
240        })
241        .collect();
242    Ok(Some(collected?))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_family::resolve_custom_family_x_psi_psi_map;
    use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
    use gam_runtime::resource::ResourcePolicy;
    use gam_terms::basis::{CenterStrategy, MaternBasisSpec, MaternIdentifiability, MaternNu};
    use gam_terms::smooth::{
        ShapeConstraint, SmoothBasisSpec, SmoothTermSpec, build_term_collection_design,
    };

    /// Absolute elementwise tolerance used when comparing embedded matrices.
    const EMBED_TOL: f64 = 1e-12;

    /// Builds the design for `spec`, freezes the spec against that design, and
    /// rebuilds the design from the frozen spec.
    fn resolve_spec_and_design(
        data: &Array2<f64>,
        spec: &TermCollectionSpec,
    ) -> (TermCollectionSpec, TermCollectionDesign) {
        let base_design =
            build_term_collection_design(data.view(), spec).expect("build base spatial design");
        let resolvedspec = freeze_term_collection_from_design(spec, &base_design)
            .expect("freeze spatial term spec");
        let resolved_design = build_term_collection_design(data.view(), &resolvedspec)
            .expect("rebuild frozen spatial design");
        (resolvedspec, resolved_design)
    }

    /// Asserts that `actual` and `expected` have the same shape and agree
    /// elementwise to within [`EMBED_TOL`].
    ///
    /// The shape check comes first because zipping two iterators of different
    /// lengths silently truncates, which would hide a mismatch.
    fn assert_all_close(
        actual: ndarray::ArrayView2<'_, f64>,
        expected: ndarray::ArrayView2<'_, f64>,
        context: &str,
    ) {
        assert_eq!(actual.dim(), expected.dim(), "{context}: shape mismatch");
        assert!(
            actual
                .iter()
                .zip(expected.iter())
                .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= EMBED_TOL),
            "{context}"
        );
    }

    /// Checks one cross-penalty row produced by the builder against the
    /// provider's local components, each embedded at
    /// `global_range`/`total_p` and tagged with the matching penalty index.
    fn assert_cross_penalties_embedded(
        actual_row: &[(usize, Array2<f64>)],
        expected_locals: &[Array2<f64>],
        expected_indices: &[usize],
        global_range: Range<usize>,
        total_p: usize,
        context: &str,
    ) {
        assert_eq!(
            actual_row.len(),
            expected_locals.len(),
            "{context}: component count"
        );
        for (((penalty_idx, actual), expected_local), expected_idx) in actual_row
            .iter()
            .zip(expected_locals)
            .zip(expected_indices)
        {
            assert_eq!(penalty_idx, expected_idx, "{context}: penalty index");
            let expected =
                EmbeddedSquareBlock::new(expected_local, global_range.clone(), total_p)
                    .materialize();
            assert_all_close(actual.view(), expected.view(), context);
        }
    }

    #[test]
    fn build_block_spatial_psi_derivatives_populates_aniso_cross_rows() {
        let n = 10usize;
        let mut data = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let x0 = i as f64 / (n as f64 - 1.0);
            let x1 = (0.37 * i as f64).sin() + 0.2 * x0;
            data[[i, 0]] = x0;
            data[[i, 1]] = x1;
        }

        let spec = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: vec![SmoothTermSpec {
                name: "spatial".to_string(),
                basis: SmoothBasisSpec::Matern {
                    feature_cols: vec![0, 1],
                    spec: MaternBasisSpec {
                        periodic: None,
                        center_strategy: CenterStrategy::EqualMass { num_centers: 6 },
                        length_scale: 0.8,
                        nu: MaternNu::ThreeHalves,
                        include_intercept: false,
                        double_penalty: false,
                        identifiability: MaternIdentifiability::CenterSumToZero,
                        aniso_log_scales: Some(vec![0.0, 0.0]),
                        nullspace_shrinkage_survived: None,
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
                joint_null_rotation: None,
            }],
        };
        let (resolvedspec, resolved_design) = resolve_spec_and_design(&data, &spec);
        let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
        let info_list = try_build_spatial_log_kappa_derivativeinfo_list(
            data.view(),
            &resolvedspec,
            &resolved_design,
            &spatial_terms,
        )
        .expect("build spatial derivative info list")
        .expect("anisotropic derivative info");
        let derivs =
            build_block_spatial_psi_derivatives(data.view(), &resolvedspec, &resolved_design)
                .expect("build custom-family spatial psi derivatives")
                .expect("anisotropic spatial derivative rows");

        assert_eq!(
            derivs.len(),
            2,
            "2D anisotropic term should expose two psi rows"
        );
        assert_eq!(
            info_list.len(),
            2,
            "info list should expose the same two psi rows"
        );

        let n_rows = resolved_design.design.nrows();
        let n_cols = resolved_design.design.ncols();
        let policy = ResourcePolicy::permissive_small_data();
        let x_cross_01_map = resolve_custom_family_x_psi_psi_map(
            &derivs[0],
            &derivs[1],
            1,
            n_rows,
            n_cols,
            0..n_rows,
            "psi0 cross design",
            &policy,
        )
        .expect("resolve psi0 cross design");
        let x_cross_10_map = resolve_custom_family_x_psi_psi_map(
            &derivs[1],
            &derivs[0],
            0,
            n_rows,
            n_cols,
            0..n_rows,
            "psi1 cross design",
            &policy,
        )
        .expect("resolve psi1 cross design");
        let x_cross_01 = x_cross_01_map
            .row_chunk(0..n_rows)
            .expect("materialize psi0 cross design");
        let x_cross_10 = x_cross_10_map
            .row_chunk(0..n_rows)
            .expect("materialize psi1 cross design");
        assert_eq!(x_cross_01.dim(), (n_rows, n_cols));
        assert_eq!(x_cross_10.dim(), (n_rows, n_cols));

        let cross_designs_01 = info_list[0]
            .aniso_cross_designs
            .as_ref()
            .expect("psi0 cross designs");
        let cross_designs_10 = info_list[1]
            .aniso_cross_designs
            .as_ref()
            .expect("psi1 cross designs");
        assert_eq!(
            cross_designs_01[0].0, 1,
            "psi0 should point at psi1 cross design"
        );
        assert_eq!(
            cross_designs_10[0].0, 0,
            "psi1 should point at psi0 cross design"
        );
        let expected_x_cross_01 = EmbeddedColumnBlock::new(
            &cross_designs_01[0].1,
            info_list[0].global_range.clone(),
            info_list[0].total_p,
        )
        .materialize();
        let expected_x_cross_10 = EmbeddedColumnBlock::new(
            &cross_designs_10[0].1,
            info_list[1].global_range.clone(),
            info_list[1].total_p,
        )
        .materialize();
        assert_all_close(
            x_cross_01.view(),
            expected_x_cross_01.view(),
            "generic psi builder should embed the psi0->psi1 cross design into the off-diagonal row",
        );
        assert_all_close(
            x_cross_10.view(),
            expected_x_cross_10.view(),
            "generic psi builder should embed the psi1->psi0 cross design into the symmetric off-diagonal row",
        );

        let s_cross_01 = &derivs[0]
            .s_psi_psi_components
            .as_ref()
            .expect("psi0 penalty second-derivative rows")[1];
        let s_cross_10 = &derivs[1]
            .s_psi_psi_components
            .as_ref()
            .expect("psi1 penalty second-derivative rows")[0];
        let cross_penalties_01 = info_list[0]
            .aniso_cross_penalty_provider
            .as_ref()
            .expect("psi0 cross penalty provider")(1)
        .expect("psi0->psi1 cross penalties");
        let cross_penalties_10 = info_list[1]
            .aniso_cross_penalty_provider
            .as_ref()
            .expect("psi1 cross penalty provider")(0)
        .expect("psi1->psi0 cross penalties");
        assert_cross_penalties_embedded(
            s_cross_01,
            &cross_penalties_01,
            &info_list[0].penalty_indices,
            info_list[0].global_range.clone(),
            info_list[0].total_p,
            "generic psi builder should embed each psi0->psi1 cross penalty component into the off-diagonal row",
        );
        assert_cross_penalties_embedded(
            s_cross_10,
            &cross_penalties_10,
            &info_list[1].penalty_indices,
            info_list[1].global_range.clone(),
            info_list[1].total_p,
            "generic psi builder should embed each psi1->psi0 cross penalty component into the symmetric off-diagonal row",
        );
    }

    #[test]
    fn build_block_spatial_psi_derivatives_supports_3d_aniso_matern() {
        let n = 24usize;
        let mut data = Array2::<f64>::zeros((n, 3));
        for i in 0..n {
            let t = i as f64 / (n as f64 - 1.0);
            data[[i, 0]] = t;
            data[[i, 1]] = (2.0 * std::f64::consts::PI * t).sin();
            data[[i, 2]] = (2.5 * std::f64::consts::PI * t).cos();
        }

        let spec = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: vec![SmoothTermSpec {
                name: "spatial".to_string(),
                basis: SmoothBasisSpec::Matern {
                    feature_cols: vec![0, 1, 2],
                    spec: MaternBasisSpec {
                        periodic: None,
                        center_strategy: CenterStrategy::EqualMass { num_centers: 6 },
                        length_scale: 0.45,
                        nu: MaternNu::ThreeHalves,
                        include_intercept: false,
                        double_penalty: false,
                        identifiability: MaternIdentifiability::CenterSumToZero,
                        aniso_log_scales: Some(vec![0.0, 0.0, 0.0]),
                        nullspace_shrinkage_survived: None,
                    },
                    input_scales: None,
                },
                shape: ShapeConstraint::None,
                joint_null_rotation: None,
            }],
        };
        let (resolvedspec, resolved_design) = resolve_spec_and_design(&data, &spec);
        let derivs =
            build_block_spatial_psi_derivatives(data.view(), &resolvedspec, &resolved_design)
                .expect("3D anisotropic Matern psi derivatives should build")
                .expect("3D anisotropic Matern psi derivatives should be present");
        assert_eq!(derivs.len(), 3);
        assert!(derivs.iter().all(|deriv| deriv.implicit_operator.is_some()));

        // All three rows come from one anisotropic term, so they should share
        // a group id and cover three distinct axes.
        let group = derivs[0].implicit_group_id;
        assert!(group.is_some(), "anisotropic rows should carry a group id");
        assert!(
            derivs.iter().all(|deriv| deriv.implicit_group_id == group),
            "all axes of one anisotropic term should share a group id"
        );
        let mut axes: Vec<usize> = derivs.iter().map(|deriv| deriv.implicit_axis).collect();
        axes.sort_unstable();
        axes.dedup();
        assert_eq!(axes.len(), 3, "each psi row should target a distinct axis");
    }
}