1use gam_terms::basis::create_ispline_derivative_dense;
2use gam_linalg::faer_ndarray::{FaerEigh, fast_ab};
3use crate::cubic_cell_kernel as exact_kernel;
4use gam_solve::pirls::LinearInequalityConstraints;
5use crate::util::span::span_index_for_breakpoints;
6use ndarray::{Array1, Array2, ArrayView2};
7
/// Checks that `breakpoints` has at least two entries, all finite and strictly increasing.
///
/// `label` prefixes every error message so callers can tell which breakpoint set failed.
///
/// # Errors
/// Returns a message naming the first offending adjacent pair (with its indices) when the
/// sequence is too short, contains a non-finite value, or is not strictly increasing.
fn validate_breakpoints(breakpoints: &[f64], label: &str) -> Result<(), String> {
    if breakpoints.len() < 2 {
        return Err(format!("{label} requires at least two breakpoints"));
    }
    for idx in 0..breakpoints.len() - 1 {
        let (lo, hi) = (breakpoints[idx], breakpoints[idx + 1]);
        // Finite values can be compared safely, so `lo < hi` is the exact negation of `lo >= hi`.
        let well_ordered = lo.is_finite() && hi.is_finite() && lo < hi;
        if !well_ordered {
            let next = idx + 1;
            return Err(format!(
                "{label} requires strictly increasing finite breakpoints; breakpoints[{idx}]={lo:.6}, breakpoints[{next}]={hi:.6}"
            ));
        }
    }
    Ok(())
}
26
27fn breakpoints_from_knots(knots: &[f64], label: &str) -> Result<Vec<f64>, String> {
30 let mut breakpoints = Vec::new();
31 for &knot in knots {
32 if breakpoints
33 .last()
34 .is_none_or(|prev: &f64| (knot - *prev).abs() > 1e-12)
35 {
36 breakpoints.push(knot);
37 }
38 }
39 validate_breakpoints(&breakpoints, label)?;
40 Ok(breakpoints)
41}
42
/// Slightly negative roundoff allowance for monotonicity slack values.
///
/// NOTE(review): presumably a computed slack is accepted as non-negative when it is at least
/// this threshold; the comparison sites are outside this section — confirm there.
pub(crate) const MONOTONICITY_SLACK_ROUNDOFF_TOL: f64 = -1.0e-10;
50
/// Errors raised while building or evaluating a [`DeviationRuntime`].
///
/// Every variant carries a human-readable `reason`. Call sites in this module convert these
/// values into `String` via `.into()` / `String::from(...)`; that conversion presumably comes
/// from the `impl_reason_error_boilerplate!` invocation for this type — TODO confirm.
#[derive(Debug, Clone)]
pub enum DeviationRuntimeError {
    /// Rejected caller input, e.g. a negative/non-finite `monotonicity_eps`, bad knots,
    /// a non-finite design value, or an out-of-range span/basis index.
    InvalidInput { reason: String },
    /// Two arrays that must agree in shape did not (transform sizes, anchor rows, `beta`).
    DimensionMismatch { reason: String },
    /// A numerical routine failed (I-spline evaluation, eigendecomposition, ...).
    NumericalFailure { reason: String },
}
71
// Generates the shared "reason"-carrying error boilerplate for each listed variant.
// NOTE(review): the macro is defined elsewhere; the `String::from(DeviationRuntimeError::..)`
// and `.into()` conversions used throughout this file suggest it provides at least
// `From<DeviationRuntimeError> for String` — confirm against the macro definition.
impl_reason_error_boilerplate! {
    DeviationRuntimeError {
        InvalidInput,
        DimensionMismatch,
        NumericalFailure,
    }
}
79
/// Anchor-orthogonalisation data attached to a [`DeviationRuntime`] by
/// `compose_anchor_orthogonalisation`.
///
/// `design_with_anchor_rows` subtracts `anchor_rows · anchor_correction` from the raw value
/// design, so the correction maps anchor columns onto deviation basis columns.
#[derive(Clone, Debug)]
pub struct InstalledFlexBlock {
    /// Correction matrix. Its row count must equal the sum of `anchor_components` column
    /// counts, and its column count must equal the basis dimension after composition; both
    /// are checked in `compose_anchor_orthogonalisation`.
    pub anchor_correction: Array2<f64>,
    /// Tags describing the column blocks that make up the anchor rows
    /// (presumably in column order — TODO confirm with the producer).
    pub anchor_components: Vec<AnchorComponentTag>,
}
103
/// Describes one column block of the anchor rows used by an [`InstalledFlexBlock`].
///
/// Within this module only `ncols` is read: the counts are summed to validate the row count
/// of `InstalledFlexBlock::anchor_correction`.
#[derive(Clone, Debug)]
pub enum AnchorComponentTag {
    /// Columns contributed by a parametric anchor block.
    Parametric {
        /// Which parametric block the columns come from.
        block: ParametricAnchorBlock,
        /// Number of anchor columns in this component.
        ncols: usize,
    },
    /// Columns contributed by a flex-block evaluation; `ncols` is the number of columns.
    FlexEvaluation { ncols: usize },
}
121
/// Identifies the parametric block referenced by [`AnchorComponentTag::Parametric`].
///
/// NOTE(review): the modelling meaning of `Marginal` vs `Logslope` is defined by code
/// outside this section — confirm there. The type is serde-(de)serializable.
#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ParametricAnchorBlock {
    Marginal,
    Logslope,
}
127
/// Integrates the product of two polynomials over `[0, width]`.
///
/// `left[i]` and `right[j]` are monomial coefficients (`left[i]` multiplies `t^i`), so the
/// result is `sum_{i,j} left[i] * right[j] * width^(i+j+1) / (i+j+1)`. Empty inputs give 0.
pub(crate) fn integrate_polynomial_product(left: &[f64], right: &[f64], width: f64) -> f64 {
    left.iter().enumerate().fold(0.0, |outer_acc, (i, &a)| {
        right.iter().enumerate().fold(outer_acc, |acc, (j, &b)| {
            let power = i + j + 1;
            acc + a * b * width.powi(power as i32) / power as f64
        })
    })
}
138
/// Runtime representation of the monotonicity-constrained cubic deviation basis.
///
/// On span `s = [endpoint_points[s], endpoint_points[s + 1]]` basis function `j` is the local
/// cubic `c0 + c1 t + c2 t^2 + c3 t^3` with `t = x - endpoint_points[s]` and
/// `ck = span_ck[[s, j]]`. Outside the support the value basis is held constant (first-span
/// `c0` on the left, `right_boundary_value_row` on the right).
#[derive(Clone, Debug)]
pub struct DeviationRuntime {
    /// Polynomial degree of the value basis (set to 3 by `try_new_with_smoothness_drop`).
    pub(crate) degree: usize,
    /// Upper bound on penalty derivative orders accepted by
    /// `integrated_derivative_penalty_with_nullity` (set to 3 at construction).
    pub(crate) value_span_degree: usize,
    /// Current number of basis columns; it shrinks after the smoothness null-space drop and
    /// may shrink again in `compose_anchor_orthogonalisation`.
    pub(crate) basis_dim: usize,
    /// Non-negative slope margin; the monotonicity constraint right-hand side is `eps - 1`.
    pub(crate) monotonicity_eps: f64,
    /// Strictly increasing breakpoints (validated at construction); `len() - 1` spans.
    pub(crate) endpoint_points: Array1<f64>,
    /// Per-span constant coefficients, shape `(n_spans, basis_dim)`.
    pub(crate) span_c0: Array2<f64>,
    /// Per-span linear coefficients, shape `(n_spans, basis_dim)`.
    pub(crate) span_c1: Array2<f64>,
    /// Per-span quadratic coefficients, shape `(n_spans, basis_dim)`.
    pub(crate) span_c2: Array2<f64>,
    /// Per-span cubic coefficients, shape `(n_spans, basis_dim)`.
    pub(crate) span_c3: Array2<f64>,
    /// Three rows per span: Bernstein coefficients (left, middle, right) of each basis
    /// function's first derivative on that span.
    pub(crate) monotonicity_constraint_rows: Array2<f64>,
    /// Basis values at the right-most breakpoint (used as the right-tail level).
    pub(crate) right_boundary_value_row: Array1<f64>,
    /// Anchor-orthogonalisation data, if a flex block has been installed.
    pub(crate) installed_flex_block: Option<InstalledFlexBlock>,
    /// Anchor rows cached at training time, consumed by `design_at_training_with_residual`.
    pub(crate) anchor_rows_at_training: Option<Array2<f64>>,
}
177
178pub(crate) fn raw_integrated_derivative_penalty(
195 endpoint_points: &Array1<f64>,
196 raw_span_c0: &Array2<f64>,
197 raw_span_c1: &Array2<f64>,
198 raw_span_c2: &Array2<f64>,
199 raw_span_c3: &Array2<f64>,
200 derivative_order: usize,
201) -> Result<Array2<f64>, String> {
202 let raw_dim = raw_span_c0.ncols();
203 let n_spans = endpoint_points.len().saturating_sub(1);
204 if raw_span_c1.ncols() != raw_dim
205 || raw_span_c2.ncols() != raw_dim
206 || raw_span_c3.ncols() != raw_dim
207 {
208 return Err("raw smoothness penalty: span coefficient column dimensions disagree".into());
209 }
210 let mut penalty = Array2::<f64>::zeros((raw_dim, raw_dim));
211 for span_idx in 0..n_spans {
212 let left = endpoint_points[span_idx];
213 let right = endpoint_points[span_idx + 1];
214 let width = right - left;
215 if !width.is_finite() || width <= 0.0 {
216 return Err(format!(
217 "raw smoothness penalty span {span_idx} has invalid width {width}"
218 ));
219 }
220 for i in 0..raw_dim {
221 let ci = raw_span_derivative_polynomial_coefficients(
222 span_idx,
223 i,
224 derivative_order,
225 raw_span_c0,
226 raw_span_c1,
227 raw_span_c2,
228 raw_span_c3,
229 );
230 for j in i..raw_dim {
231 let cj = raw_span_derivative_polynomial_coefficients(
232 span_idx,
233 j,
234 derivative_order,
235 raw_span_c0,
236 raw_span_c1,
237 raw_span_c2,
238 raw_span_c3,
239 );
240 let contribution = integrate_polynomial_product(&ci, &cj, width);
241 penalty[[i, j]] += contribution;
242 if i != j {
243 penalty[[j, i]] += contribution;
244 }
245 }
246 }
247 }
248 Ok(penalty)
249}
250
251pub(crate) fn raw_span_derivative_polynomial_coefficients(
256 span_idx: usize,
257 basis_idx: usize,
258 derivative_order: usize,
259 raw_span_c0: &Array2<f64>,
260 raw_span_c1: &Array2<f64>,
261 raw_span_c2: &Array2<f64>,
262 raw_span_c3: &Array2<f64>,
263) -> Vec<f64> {
264 let c0 = raw_span_c0[[span_idx, basis_idx]];
265 let c1 = raw_span_c1[[span_idx, basis_idx]];
266 let c2 = raw_span_c2[[span_idx, basis_idx]];
267 let c3 = raw_span_c3[[span_idx, basis_idx]];
268 match derivative_order {
269 0 => vec![c0, c1, c2, c3],
270 1 => vec![c1, 2.0 * c2, 3.0 * c3],
271 2 => vec![2.0 * c2, 6.0 * c3],
272 3 => vec![6.0 * c3],
273 _ => Vec::new(),
274 }
275}
276
277pub(crate) fn smoothness_nullspace_orthogonal_complement(
290 raw_penalty: &Array2<f64>,
291) -> Result<Array2<f64>, String> {
292 let n = raw_penalty.nrows();
293 if raw_penalty.ncols() != n {
294 return Err("smoothness penalty matrix must be square for null-space drop".to_string());
295 }
296 let (eigenvalues, eigenvectors) = raw_penalty
297 .eigh(faer::Side::Lower)
298 .map_err(|e| format!("raw smoothness penalty eigendecomposition failed: {e}"))?;
299 let evals = eigenvalues
300 .as_slice()
301 .ok_or_else(|| "raw smoothness penalty eigenvalues are not contiguous".to_string())?;
302 let threshold = gam_solve::estimate::reml::reml_outer_engine::positive_eigenvalue_threshold(evals);
303 let kept: Vec<usize> = evals
304 .iter()
305 .enumerate()
306 .filter_map(|(i, &v)| (v > threshold).then_some(i))
307 .collect();
308 if kept.is_empty() {
309 return Err(
310 "smoothness penalty has no positive eigenvalues; basis is entirely in the penalty's \
311 null space and cannot be identified after the smoothness null-space drop"
312 .to_string(),
313 );
314 }
315 if kept.len() == n {
316 return Err(
317 "smoothness penalty has no null directions; nothing to drop. The link-deviation \
318 basis was expected to carry a non-trivial null space (constants/linears) for \
319 absorption by the location block — check the configured penalty derivative order"
320 .to_string(),
321 );
322 }
323 let mut z = Array2::<f64>::zeros((n, kept.len()));
324 for (col_out, &col_in) in kept.iter().enumerate() {
325 z.column_mut(col_out).assign(&eigenvectors.column(col_in));
326 }
327 Ok(z)
328}
329
330pub(crate) fn build_quadratic_derivative_bernstein_constraints(
331 endpoint_points: &Array1<f64>,
332 span_c1: &Array2<f64>,
333 span_c2: &Array2<f64>,
334 span_c3: &Array2<f64>,
335) -> Result<Array2<f64>, String> {
336 let n_spans = endpoint_points.len().saturating_sub(1);
337 let basis_dim = span_c1.ncols();
338 let mut rows = Array2::<f64>::zeros((3 * n_spans, basis_dim));
339 for span_idx in 0..n_spans {
340 let width = endpoint_points[span_idx + 1] - endpoint_points[span_idx];
341 if !width.is_finite() || width <= 0.0 {
342 return Err(DeviationRuntimeError::InvalidInput {
343 reason: format!(
344 "DeviationRuntime monotonicity span {span_idx} has invalid width {width}"
345 ),
346 }
347 .into());
348 }
349 let left_row = 3 * span_idx;
350 let mid_row = left_row + 1;
351 let right_row = left_row + 2;
352 for basis_idx in 0..basis_dim {
353 let c1 = span_c1[[span_idx, basis_idx]];
354 let c2 = span_c2[[span_idx, basis_idx]];
355 let c3 = span_c3[[span_idx, basis_idx]];
356 rows[[left_row, basis_idx]] = c1;
367 rows[[mid_row, basis_idx]] = c1 + c2 * width;
368 rows[[right_row, basis_idx]] = c1 + 2.0 * c2 * width + 3.0 * c3 * width * width;
369 }
370 }
371 Ok(rows)
372}
373
374impl DeviationRuntime {
375 pub(crate) fn try_new(
390 knots: Array1<f64>,
391 monotonicity_eps: f64,
392 max_penalty_derivative_order: usize,
393 ) -> Result<Self, String> {
394 Self::try_new_with_smoothness_drop(knots, monotonicity_eps, max_penalty_derivative_order)
395 }
396
397 pub(super) fn try_new_with_smoothness_drop(
398 knots: Array1<f64>,
399 monotonicity_eps: f64,
400 max_penalty_derivative_order: usize,
401 ) -> Result<Self, String> {
402 if !monotonicity_eps.is_finite() || monotonicity_eps < 0.0 {
403 return Err(DeviationRuntimeError::InvalidInput {
404 reason: format!(
405 "DeviationRuntime monotonicity_eps must be finite and non-negative, got {monotonicity_eps}"
406 ),
407 }
408 .into());
409 }
410
411 let bkpts = breakpoints_from_knots(
412 knots.as_slice().ok_or_else(|| {
413 String::from(DeviationRuntimeError::InvalidInput {
414 reason: "DeviationRuntime knots are not contiguous".to_string(),
415 })
416 })?,
417 "DeviationRuntime breakpoints",
418 )?;
419 let endpoint_points = Array1::from_vec(bkpts);
420 if endpoint_points.len() < 3 {
421 return Err(DeviationRuntimeError::InvalidInput {
422 reason:
423 "DeviationRuntime requires at least two active knot spans and one interior node"
424 .to_string(),
425 }
426 .into());
427 }
428 let n_spans = endpoint_points.len() - 1;
429 for span_idx in 0..n_spans {
430 let left = endpoint_points[span_idx];
431 let right = endpoint_points[span_idx + 1];
432 let width = right - left;
433 if !width.is_finite() || width <= 0.0 {
434 return Err(DeviationRuntimeError::InvalidInput {
435 reason: format!(
436 "DeviationRuntime requires strictly increasing span endpoints at span {span_idx}: left={left}, right={right}"
437 ),
438 }
439 .into());
440 }
441 }
442 let span_lefts = Array1::from_iter((0..n_spans).map(|idx| endpoint_points[idx]));
443 let span_midpoints = Array1::from_iter(
444 (0..n_spans).map(|idx| 0.5 * (endpoint_points[idx] + endpoint_points[idx + 1])),
445 );
446 let right_endpoint = Array1::from_vec(vec![endpoint_points[n_spans]]);
447 let internal_degree = 2usize;
448 let raw_span_c0 =
449 create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 0)
450 .map_err(|e| {
451 String::from(DeviationRuntimeError::NumericalFailure {
452 reason: format!("DeviationRuntime cubic I-spline values failed: {e}"),
453 })
454 })?;
455 let raw_span_c1 =
456 create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 1)
457 .map_err(|e| {
458 String::from(DeviationRuntimeError::NumericalFailure {
459 reason: format!(
460 "DeviationRuntime cubic I-spline first derivatives failed: {e}"
461 ),
462 })
463 })?;
464 let raw_span_c2 =
465 create_ispline_derivative_dense(span_lefts.view(), &knots, internal_degree, 2)
466 .map_err(|e| {
467 String::from(DeviationRuntimeError::NumericalFailure {
468 reason: format!(
469 "DeviationRuntime cubic I-spline second derivatives failed: {e}"
470 ),
471 })
472 })?
473 .mapv(|value| 0.5 * value);
474 let raw_span_c3 =
475 create_ispline_derivative_dense(span_midpoints.view(), &knots, internal_degree, 3)
476 .map_err(|e| {
477 String::from(DeviationRuntimeError::NumericalFailure {
478 reason: format!(
479 "DeviationRuntime cubic I-spline third derivatives failed: {e}"
480 ),
481 })
482 })?
483 .mapv(|value| value / 6.0);
484 let raw_right_boundary_values =
485 create_ispline_derivative_dense(right_endpoint.view(), &knots, internal_degree, 0)
486 .map_err(|e| {
487 String::from(DeviationRuntimeError::NumericalFailure {
488 reason: format!(
489 "DeviationRuntime cubic I-spline right boundary failed: {e}"
490 ),
491 })
492 })?;
493 let raw_right_boundary_value_row = raw_right_boundary_values.row(0).to_owned();
494
495 if max_penalty_derivative_order == 0 {
496 return Err(
497 "DeviationRuntime requires max_penalty_derivative_order >= 1 so the basis can \
498 drop the corresponding smoothness null space; an order-0 (mass) penalty alone \
499 has no null space and would not require any drop"
500 .to_string(),
501 );
502 }
503 if max_penalty_derivative_order > 3 {
504 return Err(format!(
505 "DeviationRuntime cubic basis supports derivative orders up to 3; got max \
506 penalty derivative order {max_penalty_derivative_order}"
507 ));
508 }
509 let raw_smoothness_penalty = raw_integrated_derivative_penalty(
510 &endpoint_points,
511 &raw_span_c0,
512 &raw_span_c1,
513 &raw_span_c2,
514 &raw_span_c3,
515 max_penalty_derivative_order,
516 )?;
517 let coefficient_transform =
518 smoothness_nullspace_orthogonal_complement(&raw_smoothness_penalty)?;
519 let basis_dim = coefficient_transform.ncols();
520 let span_c0 = fast_ab(&raw_span_c0, &coefficient_transform);
521 let span_c1 = fast_ab(&raw_span_c1, &coefficient_transform);
522 let span_c2 = fast_ab(&raw_span_c2, &coefficient_transform);
523 let span_c3 = fast_ab(&raw_span_c3, &coefficient_transform);
524 let right_boundary_value_row = raw_right_boundary_value_row.dot(&coefficient_transform);
525 let monotonicity_constraint_rows = build_quadratic_derivative_bernstein_constraints(
526 &endpoint_points,
527 &span_c1,
528 &span_c2,
529 &span_c3,
530 )?;
531
532 Ok(Self {
533 degree: 3,
534 value_span_degree: 3,
535 basis_dim,
536 monotonicity_eps,
537 endpoint_points,
538 span_c0,
539 span_c1,
540 span_c2,
541 span_c3,
542 monotonicity_constraint_rows,
543 right_boundary_value_row,
544 installed_flex_block: None,
545 anchor_rows_at_training: None,
546 })
547 }
548
549 pub(crate) fn compose_anchor_orthogonalisation(
606 &mut self,
607 right_selector: &Array2<f64>,
608 installed_flex_block: Option<InstalledFlexBlock>,
609 ) -> Result<(), String> {
610 let old_dim = self.basis_dim;
611 if right_selector.nrows() != old_dim {
612 return Err(DeviationRuntimeError::DimensionMismatch {
613 reason: format!(
614 "DeviationRuntime cross-block transform shape mismatch: \
615 transform rows={}, expected basis_dim={}",
616 right_selector.nrows(),
617 old_dim,
618 ),
619 }
620 .into());
621 }
622 let new_dim = right_selector.ncols();
623 if new_dim == 0 {
624 return Err(DeviationRuntimeError::DimensionMismatch {
625 reason: "DeviationRuntime cross-block transform reduces basis dim to 0; \
626 the candidate's column span is fully aliased by the anchor block"
627 .to_string(),
628 }
629 .into());
630 }
631 if new_dim > old_dim {
632 return Err(DeviationRuntimeError::DimensionMismatch {
633 reason: format!(
634 "DeviationRuntime cross-block transform must not increase basis dim; \
635 got new_dim={} from old_dim={}",
636 new_dim, old_dim,
637 ),
638 }
639 .into());
640 }
641 if let Some(ref installed) = installed_flex_block {
642 let d_expected: usize = installed
643 .anchor_components
644 .iter()
645 .map(|c| match c {
646 AnchorComponentTag::Parametric { ncols, .. } => *ncols,
647 AnchorComponentTag::FlexEvaluation { ncols } => *ncols,
648 })
649 .sum();
650 if installed.anchor_correction.nrows() != d_expected {
651 return Err(DeviationRuntimeError::DimensionMismatch {
652 reason: format!(
653 "DeviationRuntime installed flex block: anchor_correction rows={}, expected sum-of-component-ncols={}",
654 installed.anchor_correction.nrows(),
655 d_expected,
656 ),
657 }
658 .into());
659 }
660 if installed.anchor_correction.ncols() != new_dim {
661 return Err(DeviationRuntimeError::DimensionMismatch {
662 reason: format!(
663 "DeviationRuntime installed flex block: anchor_correction cols={}, expected new basis dim {}",
664 installed.anchor_correction.ncols(),
665 new_dim,
666 ),
667 }
668 .into());
669 }
670 }
671 self.span_c0 = fast_ab(&self.span_c0, right_selector);
672 self.span_c1 = fast_ab(&self.span_c1, right_selector);
673 self.span_c2 = fast_ab(&self.span_c2, right_selector);
674 self.span_c3 = fast_ab(&self.span_c3, right_selector);
675 self.right_boundary_value_row = self.right_boundary_value_row.dot(right_selector);
678 self.monotonicity_constraint_rows =
683 fast_ab(&self.monotonicity_constraint_rows, right_selector);
684 self.basis_dim = new_dim;
685 self.installed_flex_block = installed_flex_block;
686 Ok(())
687 }
688
689 pub fn installed_flex_block(&self) -> Option<&InstalledFlexBlock> {
694 self.installed_flex_block.as_ref()
695 }
696
697 pub(crate) fn install_compiled_flex_block(
710 &mut self,
711 compiled: &gam_identifiability::families::compiler::CompiledBlock,
712 anchor_components: Vec<AnchorComponentTag>,
713 n_train_at_training: Array2<f64>,
714 ) -> Result<(), String> {
715 let m = compiled.anchor_correction.as_ref().ok_or_else(|| {
716 "DeviationRuntime::install_compiled_flex_block: compiled block has no \
717 anchor_correction — install requires a non-empty anchor union"
718 .to_string()
719 })?;
720 let installed = InstalledFlexBlock {
721 anchor_correction: m.clone(),
722 anchor_components,
723 };
724 self.anchor_rows_at_training = Some(n_train_at_training);
725 self.compose_anchor_orthogonalisation(&compiled.t_lw, Some(installed))
726 }
727
728 pub fn anchor_rows_at_training(&self) -> Option<&Array2<f64>> {
735 self.anchor_rows_at_training.as_ref()
736 }
737
738 pub fn design_with_anchor_rows(
744 &self,
745 values: &Array1<f64>,
746 anchor_rows: ArrayView2<f64>,
747 ) -> Result<Array2<f64>, String> {
748 let mut out = self.evaluate_span_polynomial_design_raw(values, 0)?;
749 if let Some(installed) = &self.installed_flex_block {
750 if anchor_rows.nrows() != values.len() {
751 return Err(DeviationRuntimeError::DimensionMismatch {
752 reason: format!(
753 "design_with_anchor_rows: anchor_rows has {} rows, expected {} (matching values)",
754 anchor_rows.nrows(),
755 values.len(),
756 ),
757 }
758 .into());
759 }
760 if anchor_rows.ncols() != installed.anchor_correction.nrows() {
761 return Err(DeviationRuntimeError::DimensionMismatch {
762 reason: format!(
763 "design_with_anchor_rows: anchor_rows has {} cols, expected {} (sum of component ncols)",
764 anchor_rows.ncols(),
765 installed.anchor_correction.nrows(),
766 ),
767 }
768 .into());
769 }
770 let subtract = anchor_rows.dot(&installed.anchor_correction);
771 out = out - subtract;
772 } else if anchor_rows.ncols() != 0 {
773 return Err(DeviationRuntimeError::DimensionMismatch {
776 reason: format!(
777 "design_with_anchor_rows: runtime has no installed flex block but anchor_rows has {} cols",
778 anchor_rows.ncols(),
779 ),
780 }
781 .into());
782 }
783 Ok(out)
784 }
785
786 pub(crate) fn design_at_training_with_residual(
789 &self,
790 values: &Array1<f64>,
791 ) -> Result<Array2<f64>, String> {
792 if let Some(rows) = self.anchor_rows_at_training.as_ref() {
793 self.design_with_anchor_rows(values, rows.view())
794 } else if self.installed_flex_block.is_some() {
795 Err(
796 "design_at_training_with_residual: runtime has installed_flex_block but no cached training anchor rows"
797 .to_string(),
798 )
799 } else {
800 self.design(values)
801 }
802 }
803
804 pub fn degree(&self) -> usize {
807 self.degree
808 }
809
810 pub fn value_span_degree(&self) -> usize {
811 self.value_span_degree
812 }
813
814 pub fn basis_dim(&self) -> usize {
815 self.basis_dim
816 }
817
818 pub fn monotonicity_eps(&self) -> f64 {
819 self.monotonicity_eps
820 }
821
822 pub fn span_c0(&self) -> &Array2<f64> {
823 &self.span_c0
824 }
825
826 pub fn span_c1(&self) -> &Array2<f64> {
827 &self.span_c1
828 }
829
830 pub fn span_c2(&self) -> &Array2<f64> {
831 &self.span_c2
832 }
833
834 pub fn span_c3(&self) -> &Array2<f64> {
835 &self.span_c3
836 }
837
838 pub(super) fn validate_beta_shape(
841 &self,
842 beta: &Array1<f64>,
843 label: &str,
844 ) -> Result<(), String> {
845 if beta.len() != self.basis_dim {
846 return Err(DeviationRuntimeError::DimensionMismatch {
847 reason: format!(
848 "{label} length mismatch: got {}, expected {}",
849 beta.len(),
850 self.basis_dim
851 ),
852 }
853 .into());
854 }
855 Ok::<(), _>(())
856 }
857
858 pub(super) fn evaluate_span_polynomial_design_raw(
863 &self,
864 values: &Array1<f64>,
865 derivative_order: usize,
866 ) -> Result<Array2<f64>, String> {
867 let (left_ep, right_ep) = self.support_interval()?;
868 let mut out = Array2::<f64>::zeros((values.len(), self.basis_dim));
869 for (row_idx, &value) in values.iter().enumerate() {
870 if !value.is_finite() {
871 return Err(DeviationRuntimeError::InvalidInput {
872 reason: format!(
873 "deviation runtime design value at row {row_idx} is non-finite ({value})"
874 ),
875 }
876 .into());
877 }
878 if value < left_ep {
879 if derivative_order == 0 {
880 out.row_mut(row_idx).assign(&self.span_c0.row(0));
881 }
882 continue;
883 }
884 if value > right_ep {
885 if derivative_order == 0 {
886 out.row_mut(row_idx)
887 .assign(&self.right_boundary_value_row.view());
888 }
889 continue;
890 }
891 let span_idx = self.left_biased_span_index_for(value)?;
892 let left = self.endpoint_points[span_idx];
893 let t = value - left;
894 for basis_idx in 0..self.basis_dim {
895 let c0 = self.span_c0[[span_idx, basis_idx]];
896 let c1 = self.span_c1[[span_idx, basis_idx]];
897 let c2 = self.span_c2[[span_idx, basis_idx]];
898 let c3 = self.span_c3[[span_idx, basis_idx]];
899 out[[row_idx, basis_idx]] = match derivative_order {
900 0 => c0 + c1 * t + c2 * t * t + c3 * t * t * t,
901 1 => c1 + 2.0 * c2 * t + 3.0 * c3 * t * t,
902 2 => 2.0 * c2 + 6.0 * c3 * t,
903 3 => 6.0 * c3,
904 4 => 0.0,
905 other => {
906 return Err(DeviationRuntimeError::InvalidInput {
907 reason: format!(
908 "deviation runtime only supports derivative orders up to 4, got {other}"
909 ),
910 }
911 .into());
912 }
913 };
914 }
915 }
916 Ok(out)
917 }
918
919 pub fn design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
925 assert!(
926 self.installed_flex_block.is_none(),
927 "DeviationRuntime::design called on a runtime with an installed flex block; \
928 use design_with_anchor_rows or design_at_training_with_residual instead"
929 );
930 self.evaluate_span_polynomial_design_raw(values, 0)
931 }
932
933 pub fn first_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
934 self.evaluate_span_polynomial_design_raw(values, 1)
935 }
936
937 pub fn second_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
938 self.evaluate_span_polynomial_design_raw(values, 2)
939 }
940
941 pub fn third_derivative_design(&self, values: &Array1<f64>) -> Result<Array2<f64>, String> {
942 self.evaluate_span_polynomial_design_raw(values, 3)
943 }
944
945 pub(crate) fn integrated_derivative_penalty_with_nullity(
946 &self,
947 derivative_order: usize,
948 ) -> Result<(Array2<f64>, usize), String> {
949 if derivative_order > self.value_span_degree {
950 return Err(DeviationRuntimeError::InvalidInput {
951 reason: format!(
952 "deviation penalty derivative order {derivative_order} exceeds value-basis degree {}",
953 self.value_span_degree
954 ),
955 }
956 .into());
957 }
958 let mut penalty = Array2::<f64>::zeros((self.basis_dim, self.basis_dim));
959 for span_idx in 0..self.span_count() {
960 let (left, right) = self.span_interval(span_idx)?;
961 let width = right - left;
962 if !width.is_finite() || width <= 0.0 {
963 return Err(DeviationRuntimeError::InvalidInput {
964 reason: format!("deviation penalty span {span_idx} has invalid width {width}"),
965 }
966 .into());
967 }
968 for i in 0..self.basis_dim {
969 let ci =
970 self.span_derivative_polynomial_coefficients(span_idx, i, derivative_order)?;
971 for j in i..self.basis_dim {
972 let cj = self.span_derivative_polynomial_coefficients(
973 span_idx,
974 j,
975 derivative_order,
976 )?;
977 let contribution = integrate_polynomial_product(&ci, &cj, width);
978 penalty[[i, j]] += contribution;
979 if i != j {
980 penalty[[j, i]] += contribution;
981 }
982 }
983 }
984 }
985 let (evals, _) = penalty.eigh(faer::Side::Lower).map_err(|e| {
986 String::from(DeviationRuntimeError::NumericalFailure {
987 reason: format!("deviation integrated penalty eigendecomposition failed: {e}"),
988 })
989 })?;
990 let threshold = gam_solve::estimate::reml::reml_outer_engine::positive_eigenvalue_threshold(
991 evals.as_slice().ok_or_else(|| {
992 String::from(DeviationRuntimeError::NumericalFailure {
993 reason: "deviation penalty eigenvalues are not contiguous".to_string(),
994 })
995 })?,
996 );
997 let rank = evals.iter().filter(|&&value| value > threshold).count();
998 let nullity = self.basis_dim.saturating_sub(rank);
999 Ok((penalty, nullity))
1000 }
1001
1002 pub(crate) fn structural_monotonicity_constraints(&self) -> LinearInequalityConstraints {
1003 LinearInequalityConstraints {
1004 a: self.monotonicity_constraint_rows.clone(),
1005 b: Array1::from_elem(
1006 self.monotonicity_constraint_rows.nrows(),
1007 self.monotonicity_eps - 1.0,
1008 ),
1009 }
1010 }
1011
1012 pub(super) fn span_count(&self) -> usize {
1015 self.endpoint_points.len().saturating_sub(1)
1016 }
1017
1018 pub fn breakpoints(&self) -> &Array1<f64> {
1019 &self.endpoint_points
1020 }
1021
1022 pub(super) fn span_interval(&self, span_idx: usize) -> Result<(f64, f64), String> {
1023 if span_idx >= self.span_count() {
1024 return Err(DeviationRuntimeError::InvalidInput {
1025 reason: format!(
1026 "deviation span index {} out of range for {} spans",
1027 span_idx,
1028 self.span_count()
1029 ),
1030 }
1031 .into());
1032 }
1033 Ok((
1034 self.endpoint_points[span_idx],
1035 self.endpoint_points[span_idx + 1],
1036 ))
1037 }
1038
1039 pub(super) fn span_index_for(&self, value: f64) -> Result<usize, String> {
1040 span_index_for_breakpoints(
1041 self.endpoint_points.as_slice().ok_or_else(|| {
1042 String::from(DeviationRuntimeError::InvalidInput {
1043 reason: "deviation runtime breakpoints are not contiguous".to_string(),
1044 })
1045 })?,
1046 value,
1047 "deviation span lookup",
1048 )
1049 }
1050
1051 pub(super) fn left_biased_span_index_for(&self, value: f64) -> Result<usize, String> {
1052 let mut span_idx = self.span_index_for(value)?;
1053 if span_idx > 0 && value == self.endpoint_points[span_idx] {
1057 span_idx -= 1;
1058 }
1059 Ok(span_idx)
1060 }
1061
1062 pub(super) fn span_derivative_polynomial_coefficients(
1063 &self,
1064 span_idx: usize,
1065 basis_idx: usize,
1066 derivative_order: usize,
1067 ) -> Result<Vec<f64>, String> {
1068 if span_idx >= self.span_count() {
1069 return Err(DeviationRuntimeError::InvalidInput {
1070 reason: format!(
1071 "deviation span index {} out of range for {} spans",
1072 span_idx,
1073 self.span_count()
1074 ),
1075 }
1076 .into());
1077 }
1078 if basis_idx >= self.basis_dim {
1079 return Err(DeviationRuntimeError::InvalidInput {
1080 reason: format!(
1081 "deviation basis index {} out of range for {} coefficients",
1082 basis_idx, self.basis_dim
1083 ),
1084 }
1085 .into());
1086 }
1087 let c0 = self.span_c0[[span_idx, basis_idx]];
1088 let c1 = self.span_c1[[span_idx, basis_idx]];
1089 let c2 = self.span_c2[[span_idx, basis_idx]];
1090 let c3 = self.span_c3[[span_idx, basis_idx]];
1091 match derivative_order {
1092 0 => Ok(vec![c0, c1, c2, c3]),
1093 1 => Ok(vec![c1, 2.0 * c2, 3.0 * c3]),
1094 2 => Ok(vec![2.0 * c2, 6.0 * c3]),
1095 3 => Ok(vec![6.0 * c3]),
1096 other => Err(DeviationRuntimeError::InvalidInput {
1097 reason: format!(
1098 "deviation polynomial coefficients only support derivative orders up to 3, got {other}"
1099 ),
1100 }
1101 .into()),
1102 }
1103 }
1104
1105 pub(crate) fn local_cubic_on_span(
1108 &self,
1109 beta: &Array1<f64>,
1110 span_idx: usize,
1111 ) -> Result<exact_kernel::LocalSpanCubic, String> {
1112 self.validate_beta_shape(beta, "deviation local cubic coefficients")?;
1113 let (left, right) = self.span_interval(span_idx)?;
1114 Ok(exact_kernel::LocalSpanCubic {
1115 left,
1116 right,
1117 c0: self.span_c0.row(span_idx).dot(beta),
1118 c1: self.span_c1.row(span_idx).dot(beta),
1119 c2: self.span_c2.row(span_idx).dot(beta),
1120 c3: self.span_c3.row(span_idx).dot(beta),
1121 })
1122 }
1123
1124 pub fn basis_span_cubic(
1125 &self,
1126 span_idx: usize,
1127 basis_idx: usize,
1128 ) -> Result<exact_kernel::LocalSpanCubic, String> {
1129 if basis_idx >= self.basis_dim {
1130 return Err(DeviationRuntimeError::InvalidInput {
1131 reason: format!(
1132 "deviation basis index {} out of range for {} coefficients",
1133 basis_idx, self.basis_dim
1134 ),
1135 }
1136 .into());
1137 }
1138 let (left, right) = self.span_interval(span_idx)?;
1139 Ok(exact_kernel::LocalSpanCubic {
1140 left,
1141 right,
1142 c0: self.span_c0[[span_idx, basis_idx]],
1143 c1: self.span_c1[[span_idx, basis_idx]],
1144 c2: self.span_c2[[span_idx, basis_idx]],
1145 c3: self.span_c3[[span_idx, basis_idx]],
1146 })
1147 }
1148
1149 pub fn basis_cubic_at(
1154 &self,
1155 basis_idx: usize,
1156 value: f64,
1157 ) -> Result<exact_kernel::LocalSpanCubic, String> {
1158 if basis_idx >= self.basis_dim {
1159 return Err(DeviationRuntimeError::InvalidInput {
1160 reason: format!(
1161 "deviation basis index {} out of range for {} coefficients",
1162 basis_idx, self.basis_dim
1163 ),
1164 }
1165 .into());
1166 }
1167 let (left_ep, right_ep) = self.support_interval()?;
1168 if value < left_ep {
1169 return Ok(exact_kernel::LocalSpanCubic {
1170 left: left_ep,
1171 right: left_ep + 1.0,
1172 c0: self.span_c0[[0, basis_idx]],
1173 c1: 0.0,
1174 c2: 0.0,
1175 c3: 0.0,
1176 });
1177 }
1178 if value > right_ep {
1179 return Ok(exact_kernel::LocalSpanCubic {
1180 left: right_ep,
1181 right: right_ep + 1.0,
1182 c0: self.right_boundary_value_row[basis_idx],
1183 c1: 0.0,
1184 c2: 0.0,
1185 c3: 0.0,
1186 });
1187 }
1188 let span_idx = self.left_biased_span_index_for(value)?;
1189 self.basis_span_cubic(span_idx, basis_idx)
1190 }
1191
1192 pub fn for_each_basis_cubic_at<F>(&self, value: f64, mut visit: F) -> Result<(), String>
1193 where
1194 F: FnMut(usize, exact_kernel::LocalSpanCubic) -> Result<(), String>,
1195 {
1196 let (left_ep, right_ep) = self.support_interval()?;
1197 if value < left_ep {
1198 for basis_idx in 0..self.basis_dim {
1199 visit(
1200 basis_idx,
1201 exact_kernel::LocalSpanCubic {
1202 left: left_ep,
1203 right: left_ep + 1.0,
1204 c0: self.span_c0[[0, basis_idx]],
1205 c1: 0.0,
1206 c2: 0.0,
1207 c3: 0.0,
1208 },
1209 )?;
1210 }
1211 return Ok(());
1212 }
1213 if value > right_ep {
1214 for basis_idx in 0..self.basis_dim {
1215 visit(
1216 basis_idx,
1217 exact_kernel::LocalSpanCubic {
1218 left: right_ep,
1219 right: right_ep + 1.0,
1220 c0: self.right_boundary_value_row[basis_idx],
1221 c1: 0.0,
1222 c2: 0.0,
1223 c3: 0.0,
1224 },
1225 )?;
1226 }
1227 return Ok(());
1228 }
1229
1230 let span_idx = self.left_biased_span_index_for(value)?;
1231 let (left, right) = self.span_interval(span_idx)?;
1232 for basis_idx in 0..self.basis_dim {
1233 visit(
1234 basis_idx,
1235 exact_kernel::LocalSpanCubic {
1236 left,
1237 right,
1238 c0: self.span_c0[[span_idx, basis_idx]],
1239 c1: self.span_c1[[span_idx, basis_idx]],
1240 c2: self.span_c2[[span_idx, basis_idx]],
1241 c3: self.span_c3[[span_idx, basis_idx]],
1242 },
1243 )?;
1244 }
1245 Ok(())
1246 }
1247
1248 pub(crate) fn local_cubic_at(
1253 &self,
1254 beta: &Array1<f64>,
1255 value: f64,
1256 ) -> Result<exact_kernel::LocalSpanCubic, String> {
1257 self.validate_beta_shape(beta, "deviation local cubic")?;
1258 let (left_ep, right_ep) = self.support_interval()?;
1259 if value < left_ep {
1260 return Ok(exact_kernel::LocalSpanCubic {
1261 left: left_ep,
1262 right: left_ep + 1.0,
1263 c0: self.left_tail_value(beta),
1264 c1: 0.0,
1265 c2: 0.0,
1266 c3: 0.0,
1267 });
1268 }
1269 if value > right_ep {
1270 return Ok(exact_kernel::LocalSpanCubic {
1271 left: right_ep,
1272 right: right_ep + 1.0,
1273 c0: self.right_tail_value(beta),
1274 c1: 0.0,
1275 c2: 0.0,
1276 c3: 0.0,
1277 });
1278 }
1279 let span_idx = self.left_biased_span_index_for(value)?;
1280 self.local_cubic_on_span(beta, span_idx)
1281 }
1282
1283 pub(super) fn left_tail_value(&self, beta: &Array1<f64>) -> f64 {
1288 self.span_c0.row(0).dot(beta)
1289 }
1290
1291 pub(super) fn right_tail_value(&self, beta: &Array1<f64>) -> f64 {
1294 self.right_boundary_value_row.dot(beta)
1295 }
1296
1297 pub(crate) fn value_basis_l1_sup_norm(&self) -> f64 {
1306 let mut total = 0.0;
1307 for basis_idx in 0..self.basis_dim {
1308 let mut col_sup = self.span_c0[[0, basis_idx]]
1309 .abs()
1310 .max(self.right_boundary_value_row[basis_idx].abs());
1311 for span_idx in 0..self.span_count() {
1312 let left = self.endpoint_points[span_idx];
1313 let right = self.endpoint_points[span_idx + 1];
1314 let width = right - left;
1315 if !width.is_finite() || width <= 0.0 {
1316 continue;
1317 }
1318 let c0 = self.span_c0[[span_idx, basis_idx]];
1319 let c1 = self.span_c1[[span_idx, basis_idx]];
1320 let c2 = self.span_c2[[span_idx, basis_idx]];
1321 let c3 = self.span_c3[[span_idx, basis_idx]];
1322 let eval_abs = |t: f64| (c0 + c1 * t + c2 * t * t + c3 * t * t * t).abs();
1323 col_sup = col_sup.max(eval_abs(0.0)).max(eval_abs(width));
1324 let a = 3.0 * c3;
1325 let b = 2.0 * c2;
1326 let c = c1;
1327 if a.abs() <= f64::EPSILON {
1328 if b.abs() > f64::EPSILON {
1329 let t = -c / b;
1330 if t > 0.0 && t < width {
1331 col_sup = col_sup.max(eval_abs(t));
1332 }
1333 }
1334 } else {
1335 let disc = b * b - 4.0 * a * c;
1336 if disc >= 0.0 {
1337 let sqrt_disc = disc.sqrt();
1338 for t in [(-b - sqrt_disc) / (2.0 * a), (-b + sqrt_disc) / (2.0 * a)] {
1339 if t > 0.0 && t < width {
1340 col_sup = col_sup.max(eval_abs(t));
1341 }
1342 }
1343 }
1344 }
1345 }
1346 total += col_sup;
1347 }
1348 total
1349 }
1350
1351 pub(super) fn support_interval(&self) -> Result<(f64, f64), String> {
1354 match (self.endpoint_points.first(), self.endpoint_points.last()) {
1355 (Some(&left), Some(&right)) => Ok((left, right)),
1356 _ => Err(DeviationRuntimeError::InvalidInput {
1357 reason: "deviation runtime is missing monotonicity support points".to_string(),
1358 }
1359 .into()),
1360 }
1361 }
1362
1363 pub(crate) fn exact_monotonicity_min_slack(&self, beta: &Array1<f64>) -> Result<f64, String> {
1364 if beta.len() != self.basis_dim {
1365 return Err(DeviationRuntimeError::DimensionMismatch {
1366 reason: format!(
1367 "deviation monotonicity length mismatch: got {}, expected {}",
1368 beta.len(),
1369 self.basis_dim
1370 ),
1371 }
1372 .into());
1373 }
1374 if beta.iter().any(|value| !value.is_finite()) {
1375 let bad = beta
1376 .iter()
1377 .enumerate()
1378 .find(|(_, value)| !value.is_finite())
1379 .map(|(idx, value)| format!("deviation coefficient {idx} is non-finite ({value})"))
1380 .unwrap_or_else(|| "deviation coefficient is non-finite".to_string());
1381 return Err(DeviationRuntimeError::InvalidInput { reason: bad }.into());
1382 }
1383
1384 let mut min_slack = f64::INFINITY;
1385 for span_idx in 0..self.span_count() {
1386 let left = self.endpoint_points[span_idx];
1387 let right = self.endpoint_points[span_idx + 1];
1388 let width = right - left;
1389 if !width.is_finite() || width <= 0.0 {
1390 continue;
1391 }
1392 let c1 = self.span_c1.row(span_idx).dot(beta);
1393 let c2 = self.span_c2.row(span_idx).dot(beta);
1394 let c3 = self.span_c3.row(span_idx).dot(beta);
1395 let d1_left = c1;
1396 let d1_right = c1 + 2.0 * c2 * width + 3.0 * c3 * width * width;
1397 let d2_left = 2.0 * c2;
1398 let d3 = 6.0 * c3;
1399 let left_slack = 1.0 + d1_left - self.monotonicity_eps;
1400 let right_slack = 1.0 + d1_right - self.monotonicity_eps;
1401 min_slack = min_slack.min(left_slack.min(right_slack));
1402
1403 if d3 > 0.0 {
1404 let t_star = -d2_left / d3;
1405 if t_star > 0.0 && t_star < width {
1406 let interior = 1.0 + d1_left + d2_left * t_star + 0.5 * d3 * t_star * t_star
1407 - self.monotonicity_eps;
1408 min_slack = min_slack.min(interior);
1409 }
1410 }
1411 }
1412 if min_slack.is_finite() {
1413 Ok(min_slack)
1414 } else {
1415 Err(DeviationRuntimeError::NumericalFailure {
1416 reason: "deviation monotonicity slack computation produced no active spans"
1417 .to_string(),
1418 }
1419 .into())
1420 }
1421 }
1422
1423 pub(crate) fn monotonicity_feasible(
1424 &self,
1425 beta: &Array1<f64>,
1426 context: &str,
1427 ) -> Result<(), String> {
1428 let slack = self.exact_monotonicity_min_slack(beta)?;
1429 if slack >= MONOTONICITY_SLACK_ROUNDOFF_TOL {
1430 Ok(())
1431 } else {
1432 let (left, right) = self.support_interval()?;
1433 Err(DeviationRuntimeError::NumericalFailure {
1434 reason: format!(
1435 "{context} violates exact monotonicity on [{left:.6}, {right:.6}] (minimum derivative slack {slack:.3e}, eps={:.3e})",
1436 self.monotonicity_eps
1437 ),
1438 }
1439 .into())
1440 }
1441 }
1442}