1use crate::custom_family::{
14 CustomFamilyBlockPsiDerivative, CustomFamilyPsiDerivativeOperator,
15 EmbeddedImplicitPsiDerivativeOperator, build_embedded_dense_psi_operator,
16};
17use gam_linalg::matrix::{EmbeddedColumnBlock, EmbeddedSquareBlock};
18use gam_terms::smooth::{TermCollectionDesign, TermCollectionSpec};
19use crate::fit_orchestration::drivers::{
20 spatial_length_scale_term_indices, try_build_spatial_log_kappa_derivativeinfo_list,
21};
22use ndarray::Array2;
23use std::collections::HashMap;
24use std::ops::Range;
25use std::sync::Arc;
26
27pub(crate) fn wrap_spatial_implicit_psi_operator(
28 op: Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
29 global_range: Range<usize>,
30 total_p: usize,
31) -> Arc<dyn CustomFamilyPsiDerivativeOperator> {
32 Arc::new(
33 EmbeddedImplicitPsiDerivativeOperator::new(op, global_range, total_p)
34 .expect("spatial implicit psi operator should embed into full coefficient space"),
35 )
36}
37
38pub(crate) trait SpatialPsiBlockTransform {
52 fn transform_operator(
55 &self,
56 op: Arc<dyn CustomFamilyPsiDerivativeOperator>,
57 ) -> Result<Arc<dyn CustomFamilyPsiDerivativeOperator>, String> {
58 Ok(op)
59 }
60
61 fn transform_design(&self, design: Array2<f64>) -> Array2<f64> {
63 design
64 }
65
66 fn transform_penalty(&self, penalty: Array2<f64>) -> Array2<f64> {
68 penalty
69 }
70}
71
72pub(crate) struct IdentitySpatialPsiBlockTransform;
74
75impl SpatialPsiBlockTransform for IdentitySpatialPsiBlockTransform {}
76
77pub(crate) fn build_block_spatial_psi_derivatives(
78 data: ndarray::ArrayView2<'_, f64>,
79 resolvedspec: &TermCollectionSpec,
80 design: &TermCollectionDesign,
81) -> Result<Option<Vec<CustomFamilyBlockPsiDerivative>>, String> {
82 build_block_spatial_psi_derivatives_with_transform(
83 data,
84 resolvedspec,
85 design,
86 &IdentitySpatialPsiBlockTransform,
87 )
88}
89
90pub(crate) fn build_block_spatial_psi_derivatives_with_transform(
98 data: ndarray::ArrayView2<'_, f64>,
99 resolvedspec: &TermCollectionSpec,
100 design: &TermCollectionDesign,
101 transform: &dyn SpatialPsiBlockTransform,
102) -> Result<Option<Vec<CustomFamilyBlockPsiDerivative>>, String> {
103 let spatial_terms = spatial_length_scale_term_indices(resolvedspec);
104 let Some(info_list) =
105 try_build_spatial_log_kappa_derivativeinfo_list(data, resolvedspec, design, &spatial_terms)
106 .map_err(|e| e.to_string())?
107 else {
108 return Ok(None);
109 };
110 let psi_dim = info_list.len();
111 let axis_lookup: HashMap<(usize, usize), usize> = info_list
112 .iter()
113 .enumerate()
114 .filter_map(|(idx, info)| {
115 info.aniso_group_id
116 .map(|gid| ((gid, info.implicit_axis), idx))
117 })
118 .collect();
119 let collected: Result<Vec<CustomFamilyBlockPsiDerivative>, String> = info_list
120 .into_iter()
121 .enumerate()
122 .map(|(psi_idx, info)| {
123 let implicit_operator = info.implicit_operator.as_ref().map(|op| {
124 wrap_spatial_implicit_psi_operator(
125 Arc::clone(op),
126 info.global_range.clone(),
127 info.total_p,
128 )
129 });
130 let dense_operator = if implicit_operator.is_none() && !info.x_psi_local.is_empty() {
131 Some(build_embedded_dense_psi_operator(
132 &info.x_psi_local,
133 &info.x_psi_psi_local,
134 info.aniso_cross_designs.as_ref(),
135 info.global_range.clone(),
136 info.total_p,
137 info.implicit_axis,
138 )?)
139 } else {
140 None
141 };
142 let design_operator = implicit_operator
143 .or(dense_operator)
144 .map(|op| transform.transform_operator(op))
145 .transpose()?;
146 let materialize_dense_design =
147 !info.x_psi_local.is_empty() && design_operator.is_none();
148 let embed_design = |local: &Array2<f64>| -> Array2<f64> {
149 let embedded = if local.ncols() == 0 || local.nrows() == 0 {
150 Array2::<f64>::zeros((local.nrows(), info.total_p))
151 } else {
152 EmbeddedColumnBlock::new(local, info.global_range.clone(), info.total_p)
153 .materialize()
154 };
155 transform.transform_design(embedded)
156 };
157 let x_full = if materialize_dense_design {
158 embed_design(&info.x_psi_local)
159 } else {
160 Array2::<f64>::zeros((0, 0))
161 };
162 let penalty_indices = info.penalty_indices.clone();
163 let embed_penalty = |local: &Array2<f64>| -> Array2<f64> {
164 let embedded = if local.nrows() == 0 || local.ncols() == 0 {
165 Array2::<f64>::zeros((info.total_p, info.total_p))
166 } else {
167 EmbeddedSquareBlock::new(local, info.global_range.clone(), info.total_p)
168 .materialize()
169 };
170 transform.transform_penalty(embedded)
171 };
172 let s_components: Vec<(usize, Array2<f64>)> = info
173 .penalty_indices
174 .into_iter()
175 .zip(
176 info.s_psi_components_local
177 .into_iter()
178 .map(|local| embed_penalty(&local)),
179 )
180 .collect();
181 let x_psi_psi_rows = if materialize_dense_design {
183 let mut rows =
184 vec![Array2::<f64>::zeros((x_full.nrows(), x_full.ncols())); psi_dim];
185 rows[psi_idx] = embed_design(&info.x_psi_psi_local);
186 if let (Some(gid), Some(cross_designs)) =
187 (info.aniso_group_id, info.aniso_cross_designs.as_ref())
188 {
189 for (axis_j, local) in cross_designs {
190 if let Some(&global_j) = axis_lookup.get(&(gid, *axis_j)) {
191 rows[global_j] = embed_design(local);
192 }
193 }
194 }
195 Some(rows)
196 } else {
197 None
198 };
199 let mut s_psi_psi_comp_rows = vec![Vec::<(usize, Array2<f64>)>::new(); psi_dim];
201 s_psi_psi_comp_rows[psi_idx] = penalty_indices
202 .iter()
203 .copied()
204 .zip(info.s_psi_psi_components_local.iter().map(&embed_penalty))
205 .collect();
206 if let (Some(gid), Some(cross_penalty_provider)) = (
207 info.aniso_group_id,
208 info.aniso_cross_penalty_provider.as_ref(),
209 ) {
210 for ((group_id, axis_j), global_j) in &axis_lookup {
211 if *group_id != gid || *axis_j == info.implicit_axis {
212 continue;
213 }
214 let local_components =
215 cross_penalty_provider(*axis_j).map_err(|err| err.to_string())?;
216 if local_components.is_empty() {
217 continue;
218 }
219 s_psi_psi_comp_rows[*global_j] = penalty_indices
220 .iter()
221 .copied()
222 .zip(local_components.iter().map(embed_penalty))
223 .collect();
224 }
225 }
226 Ok(CustomFamilyBlockPsiDerivative {
227 penalty_index: Some(info.penalty_index),
228 x_psi: x_full,
229 s_psi: Array2::<f64>::zeros((0, 0)),
230 s_psi_components: Some(s_components),
231 s_psi_penalty_components: None,
232 x_psi_psi: x_psi_psi_rows,
233 s_psi_psi: None,
234 s_psi_psi_components: Some(s_psi_psi_comp_rows),
235 s_psi_psi_penalty_components: None,
236 implicit_operator: design_operator,
237 implicit_axis: info.implicit_axis,
238 implicit_group_id: info.aniso_group_id,
239 })
240 })
241 .collect();
242 Ok(Some(collected?))
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_family::resolve_custom_family_x_psi_psi_map;
    use crate::fit_orchestration::drivers::freeze_term_collection_from_design;
    use gam_runtime::resource::ResourcePolicy;
    use gam_terms::basis::{CenterStrategy, MaternBasisSpec, MaternIdentifiability, MaternNu};
    use gam_terms::smooth::{
        ShapeConstraint, SmoothBasisSpec, SmoothTermSpec, build_term_collection_design,
    };

    /// Absolute tolerance used when comparing embedded matrices element by element.
    const TOL: f64 = 1e-12;

    /// For a 2D anisotropic Matern term, the cross design resolved between the two psi
    /// blocks and the off-diagonal penalty rows must equal the embedded cross quantities
    /// reported by the derivative info list.
    #[test]
    fn build_block_spatial_psi_derivatives_populates_aniso_cross_rows() {
        let n = 10usize;
        let data = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| {
            let x0 = row as f64 / (n as f64 - 1.0);
            match col {
                0 => x0,
                _ => (0.37 * row as f64).sin() + 0.2 * x0,
            }
        });

        let matern = MaternBasisSpec {
            periodic: None,
            center_strategy: CenterStrategy::EqualMass { num_centers: 6 },
            length_scale: 0.8,
            nu: MaternNu::ThreeHalves,
            include_intercept: false,
            double_penalty: false,
            identifiability: MaternIdentifiability::CenterSumToZero,
            aniso_log_scales: Some(vec![0.0, 0.0]),
            nullspace_shrinkage_survived: None,
        };
        let spatial_term = SmoothTermSpec {
            name: "spatial".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1],
                spec: matern,
                input_scales: None,
            },
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        };
        let spec = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: vec![spatial_term],
        };

        let base_design =
            build_term_collection_design(data.view(), &spec).expect("build base spatial design");
        let resolvedspec = freeze_term_collection_from_design(&spec, &base_design)
            .expect("freeze spatial term spec");
        let resolved_design = build_term_collection_design(data.view(), &resolvedspec)
            .expect("rebuild frozen spatial design");
        let spatial_terms = spatial_length_scale_term_indices(&resolvedspec);
        let info_list = try_build_spatial_log_kappa_derivativeinfo_list(
            data.view(),
            &resolvedspec,
            &resolved_design,
            &spatial_terms,
        )
        .expect("build spatial derivative info list")
        .expect("anisotropic derivative info");
        let derivs =
            build_block_spatial_psi_derivatives(data.view(), &resolvedspec, &resolved_design)
                .expect("build custom-family spatial psi derivatives")
                .expect("anisotropic spatial derivative rows");

        assert_eq!(
            derivs.len(),
            2,
            "2D anisotropic term should expose two psi rows"
        );
        assert_eq!(
            info_list.len(),
            2,
            "info list should expose the same two psi rows"
        );

        // Cross designs resolved through the custom-family map, in both directions.
        let policy = ResourcePolicy::permissive_small_data();
        let n_rows = resolved_design.design.nrows();
        let n_cols = resolved_design.design.ncols();
        let x_cross_01_map = resolve_custom_family_x_psi_psi_map(
            &derivs[0],
            &derivs[1],
            1,
            n_rows,
            n_cols,
            0..n_rows,
            "psi0 cross design",
            &policy,
        )
        .expect("resolve psi0 cross design");
        let x_cross_10_map = resolve_custom_family_x_psi_psi_map(
            &derivs[1],
            &derivs[0],
            0,
            n_rows,
            n_cols,
            0..n_rows,
            "psi1 cross design",
            &policy,
        )
        .expect("resolve psi1 cross design");
        let x_cross_01 = x_cross_01_map
            .row_chunk(0..n_rows)
            .expect("materialize psi0 cross design");
        let x_cross_10 = x_cross_10_map
            .row_chunk(0..n_rows)
            .expect("materialize psi1 cross design");
        assert_eq!(x_cross_01.dim(), (n_rows, n_cols));
        assert_eq!(x_cross_10.dim(), (n_rows, n_cols));

        // Reference values: the info list's local cross designs, embedded by hand.
        let cross_designs_01 = info_list[0]
            .aniso_cross_designs
            .as_ref()
            .expect("psi0 cross designs");
        let cross_designs_10 = info_list[1]
            .aniso_cross_designs
            .as_ref()
            .expect("psi1 cross designs");
        assert_eq!(
            cross_designs_01[0].0, 1,
            "psi0 should point at psi1 cross design"
        );
        assert_eq!(
            cross_designs_10[0].0, 0,
            "psi1 should point at psi0 cross design"
        );
        let expected_x_cross_01 = EmbeddedColumnBlock::new(
            &cross_designs_01[0].1,
            info_list[0].global_range.clone(),
            info_list[0].total_p,
        )
        .materialize();
        let expected_x_cross_10 = EmbeddedColumnBlock::new(
            &cross_designs_10[0].1,
            info_list[1].global_range.clone(),
            info_list[1].total_p,
        )
        .materialize();
        let cross_01_matches = x_cross_01
            .iter()
            .zip(expected_x_cross_01.iter())
            .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= TOL);
        assert!(
            cross_01_matches,
            "generic psi builder should embed the psi0->psi1 cross design into the off-diagonal row"
        );
        let cross_10_matches = x_cross_10
            .iter()
            .zip(expected_x_cross_10.iter())
            .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= TOL);
        assert!(
            cross_10_matches,
            "generic psi builder should embed the psi1->psi0 cross design into the symmetric off-diagonal row"
        );

        // Off-diagonal penalty rows versus the provider's components, embedded by hand.
        let s_cross_01 = derivs[0]
            .s_psi_psi_components
            .as_ref()
            .expect("psi0 penalty second-derivative rows")[1]
            .clone();
        let s_cross_10 = derivs[1]
            .s_psi_psi_components
            .as_ref()
            .expect("psi1 penalty second-derivative rows")[0]
            .clone();
        let cross_penalties_01 = info_list[0]
            .aniso_cross_penalty_provider
            .as_ref()
            .expect("psi0 cross penalty provider")(1)
        .expect("psi0->psi1 cross penalties");
        let cross_penalties_10 = info_list[1]
            .aniso_cross_penalty_provider
            .as_ref()
            .expect("psi1 cross penalty provider")(0)
        .expect("psi1->psi0 cross penalties");
        assert_eq!(s_cross_01.len(), cross_penalties_01.len());
        assert_eq!(s_cross_10.len(), cross_penalties_10.len());

        let triples_01 = s_cross_01
            .iter()
            .zip(cross_penalties_01.iter())
            .zip(info_list[0].penalty_indices.iter());
        for ((entry, expected_local), expected_idx) in triples_01 {
            let (penalty_idx, actual) = entry;
            assert_eq!(*penalty_idx, *expected_idx);
            let expected = EmbeddedSquareBlock::new(
                expected_local,
                info_list[0].global_range.clone(),
                info_list[0].total_p,
            )
            .materialize();
            assert_eq!(actual.dim(), expected.dim());
            let matches = actual
                .iter()
                .zip(expected.iter())
                .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= TOL);
            assert!(
                matches,
                "generic psi builder should embed each psi0->psi1 cross penalty component into the off-diagonal row"
            );
        }

        let triples_10 = s_cross_10
            .iter()
            .zip(cross_penalties_10.iter())
            .zip(info_list[1].penalty_indices.iter());
        for ((entry, expected_local), expected_idx) in triples_10 {
            let (penalty_idx, actual) = entry;
            assert_eq!(*penalty_idx, *expected_idx);
            let expected = EmbeddedSquareBlock::new(
                expected_local,
                info_list[1].global_range.clone(),
                info_list[1].total_p,
            )
            .materialize();
            assert_eq!(actual.dim(), expected.dim());
            let matches = actual
                .iter()
                .zip(expected.iter())
                .all(|(lhs, rhs)| (*lhs - *rhs).abs() <= TOL);
            assert!(
                matches,
                "generic psi builder should embed each psi1->psi0 cross penalty component into the symmetric off-diagonal row"
            );
        }
    }

    /// A 3D anisotropic Matern term must yield three psi blocks, each carrying an
    /// implicit operator.
    #[test]
    fn build_block_spatial_psi_derivatives_supports_3d_aniso_matern() {
        let n = 24usize;
        let data = Array2::<f64>::from_shape_fn((n, 3), |(row, col)| {
            let t = row as f64 / (n as f64 - 1.0);
            match col {
                0 => t,
                1 => (2.0 * std::f64::consts::PI * t).sin(),
                _ => (2.5 * std::f64::consts::PI * t).cos(),
            }
        });

        let matern = MaternBasisSpec {
            periodic: None,
            center_strategy: CenterStrategy::EqualMass { num_centers: 6 },
            length_scale: 0.45,
            nu: MaternNu::ThreeHalves,
            include_intercept: false,
            double_penalty: false,
            identifiability: MaternIdentifiability::CenterSumToZero,
            aniso_log_scales: Some(vec![0.0, 0.0, 0.0]),
            nullspace_shrinkage_survived: None,
        };
        let spatial_term = SmoothTermSpec {
            name: "spatial".to_string(),
            basis: SmoothBasisSpec::Matern {
                feature_cols: vec![0, 1, 2],
                spec: matern,
                input_scales: None,
            },
            shape: ShapeConstraint::None,
            joint_null_rotation: None,
        };
        let spec = TermCollectionSpec {
            linear_terms: Vec::new(),
            random_effect_terms: Vec::new(),
            smooth_terms: vec![spatial_term],
        };

        let base_design =
            build_term_collection_design(data.view(), &spec).expect("build base spatial design");
        let resolvedspec = freeze_term_collection_from_design(&spec, &base_design)
            .expect("freeze spatial term spec");
        let resolved_design = build_term_collection_design(data.view(), &resolvedspec)
            .expect("rebuild frozen spatial design");
        let derivs =
            build_block_spatial_psi_derivatives(data.view(), &resolvedspec, &resolved_design)
                .expect("3D anisotropic Matern psi derivatives should build")
                .expect("3D anisotropic Matern psi derivatives should be present");
        assert_eq!(derivs.len(), 3);
        assert!(derivs.iter().all(|deriv| deriv.implicit_operator.is_some()));
    }
}