1use gam_linalg::faer_ndarray::{FaerSvd, fast_ab};
2use gam_linalg::matrix::{DenseDesignMatrix, DenseDesignOperator, DesignMatrix, LinearOperator};
3use ndarray::{Array1, Array2, ArrayViewMut2, s};
4use std::ops::Range;
5use std::sync::Arc;
6
/// Failure categories raised while loading, building, or applying a
/// scale-deviation design transform.
///
/// Every variant carries a human-readable `reason` describing the specific
/// failure.
#[derive(Debug, Clone)]
pub enum ScaleDesignError {
    /// A weight is negative, or the weights do not sum to a positive finite
    /// total.
    InvalidWeights { reason: String },
    /// Row/column counts of the designs, the weights, or a transform disagree.
    IncompatibleDimensions { reason: String },
    /// A weight is not finite, or a saved `projection_ridge_alpha` is not a
    /// finite non-negative number.
    NonFiniteInput { reason: String },
    /// A saved transform payload cannot be used as-is (partially populated,
    /// zero-row projection, or missing `projection_ridge_alpha`).
    DegenerateDesign { reason: String },
    /// Materializing a block of design rows failed.
    RowMaterializationFailed { reason: String },
    /// The SVD behind the projection solve failed or returned no singular
    /// vectors.
    SvdFailed { reason: String },
}
31
32impl_reason_error_boilerplate! {
33 ScaleDesignError {
34 InvalidWeights,
35 IncompatibleDimensions,
36 NonFiniteInput,
37 DegenerateDesign,
38 RowMaterializationFailed,
39 SvdFailed,
40 }
41}
42
/// Absolute threshold on a weighted centered sum of squares. Columns at or
/// below it count as constant during intercept detection, and columns at or
/// below it are left unscaled when the rescale factors are computed.
const COLUMN_TOL: f64 = 1.0e-12;
/// Byte budget (8 MiB of `f64` values) used to size the row chunks and the
/// right-hand-side column chunks processed at a time.
const SCALE_DESIGN_TARGET_CHUNK_BYTES: usize = 8 << 20;
/// Relative floor, as a fraction of the largest singular value, on the
/// singular-value truncation tolerance of the projection solve.
const SCALE_PROJECTION_REPLAY_RCOND_FLOOR: f64 = 1.0e-8;
/// Divisor applied to the largest singular value to derive the truncation
/// tolerance of the projection solve.
const SCALE_PROJECTION_LEVERAGE_AMPLIFICATION: f64 = 1e8;
/// The operator reports `uses_matrix_free_pcg() == true` once rows × noise
/// columns exceeds this count.
const SCALE_OPERATOR_MATRIX_FREE_PCG_THRESHOLD: usize = 1_000_000;
68
69#[derive(Clone, Debug)]
70pub struct ScaleDeviationTransform {
71 pub projection_coef: Array2<f64>,
72 pub weighted_column_mean: Array1<f64>,
73 pub rescale: Array1<f64>,
74 pub non_intercept_start: usize,
75 pub projection_ridge_alpha: f64,
79}
80
81impl ScaleDeviationTransform {
82 pub fn identity(p_primary: usize, p_noise: usize, non_intercept_start: usize) -> Self {
98 ScaleDeviationTransform {
99 projection_coef: Array2::<f64>::zeros((p_primary, p_noise)),
100 weighted_column_mean: Array1::<f64>::zeros(p_noise),
101 rescale: Array1::<f64>::ones(p_noise),
102 non_intercept_start,
103 projection_ridge_alpha: 0.0,
104 }
105 }
106}
107
108pub fn scale_transform_from_payload(
114 projection: &Option<Vec<Vec<f64>>>,
115 center: &Option<Vec<f64>>,
116 scale: &Option<Vec<f64>>,
117 non_intercept_start: Option<usize>,
118 projection_ridge_alpha: Option<f64>,
119) -> Result<Option<ScaleDeviationTransform>, String> {
120 scale_transform_from_payload_typed(
121 projection,
122 center,
123 scale,
124 non_intercept_start,
125 projection_ridge_alpha,
126 )
127 .map_err(|e| e.to_string())
128}
129
130fn scale_transform_from_payload_typed(
131 projection: &Option<Vec<Vec<f64>>>,
132 center: &Option<Vec<f64>>,
133 scale: &Option<Vec<f64>>,
134 non_intercept_start: Option<usize>,
135 projection_ridge_alpha: Option<f64>,
136) -> Result<Option<ScaleDeviationTransform>, ScaleDesignError> {
137 match (projection, center, scale, non_intercept_start) {
138 (None, None, None, None) => Ok(None),
139 (Some(projection), Some(center), Some(scale), Some(non_intercept_start)) => {
140 let rows = projection.len();
141 let cols = center.len();
142 if cols != scale.len() {
143 return Err(ScaleDesignError::IncompatibleDimensions {
144 reason: "saved scale transform center/scale length mismatch".to_string(),
145 });
146 }
147 if rows == 0 && cols > 0 {
148 return Err(ScaleDesignError::DegenerateDesign {
149 reason: "saved scale transform projection has zero rows".to_string(),
150 });
151 }
152 let mut projection_coef = Array2::<f64>::zeros((rows, cols));
153 for (i, row) in projection.iter().enumerate() {
154 if row.len() != cols {
155 return Err(ScaleDesignError::IncompatibleDimensions {
156 reason: "saved scale transform projection width mismatch".to_string(),
157 });
158 }
159 for (j, &value) in row.iter().enumerate() {
160 projection_coef[[i, j]] = value;
161 }
162 }
163 let Some(projection_ridge_alpha) = projection_ridge_alpha else {
164 return Err(ScaleDesignError::DegenerateDesign {
165 reason:
166 "saved scale transform payload is missing projection_ridge_alpha; refit"
167 .to_string(),
168 });
169 };
170 if !projection_ridge_alpha.is_finite() || projection_ridge_alpha < 0.0 {
171 return Err(ScaleDesignError::NonFiniteInput {
172 reason: format!(
173 "saved scale transform projection_ridge_alpha must be finite and non-negative, got {projection_ridge_alpha}"
174 ),
175 });
176 }
177 Ok(Some(ScaleDeviationTransform {
178 projection_coef,
179 weighted_column_mean: Array1::from_vec(center.clone()),
180 rescale: Array1::from_vec(scale.clone()),
181 non_intercept_start,
182 projection_ridge_alpha,
183 }))
184 }
185 _ => Err(ScaleDesignError::DegenerateDesign {
186 reason: "saved scale transform payload is only partially populated; refit".to_string(),
187 }),
188 }
189}
190
191#[derive(Clone, Copy)]
192enum ScaleDesignMatrixRef<'a> {
193 Dense(&'a Array2<f64>),
194 Design(&'a DesignMatrix),
195}
196
197impl ScaleDesignMatrixRef<'_> {
198 #[inline]
199 fn nrows(self) -> usize {
200 match self {
201 Self::Dense(matrix) => matrix.nrows(),
202 Self::Design(matrix) => matrix.nrows(),
203 }
204 }
205
206 #[inline]
207 fn ncols(self) -> usize {
208 match self {
209 Self::Dense(matrix) => matrix.ncols(),
210 Self::Design(matrix) => matrix.ncols(),
211 }
212 }
213
214 fn row_chunk(self, rows: Range<usize>) -> Result<Array2<f64>, ScaleDesignError> {
215 match self {
216 Self::Dense(matrix) => Ok(matrix.slice(s![rows, ..]).to_owned()),
217 Self::Design(matrix) => {
218 matrix
219 .try_row_chunk(rows)
220 .map_err(|e| ScaleDesignError::RowMaterializationFailed {
221 reason: format!("scale deviation row materialization failed: {e}"),
222 })
223 }
224 }
225 }
226}
227
228pub fn infer_non_intercept_start(design: &Array2<f64>, weights: &Array1<f64>) -> usize {
229 infer_non_intercept_start_impl(
230 ScaleDesignMatrixRef::Dense(design),
231 weights,
232 "weighted column stats row mismatch".to_string(),
233 )
234 .unwrap_or(0)
235}
236
237fn dim_err(reason: impl Into<String>) -> ScaleDesignError {
238 ScaleDesignError::IncompatibleDimensions {
239 reason: reason.into(),
240 }
241}
242
243pub fn build_scale_deviation_transform(
244 primary_design: &Array2<f64>,
245 noise_design: &Array2<f64>,
246 weights: &Array1<f64>,
247 non_intercept_start: usize,
248) -> Result<ScaleDeviationTransform, String> {
249 build_scale_deviation_transform_impl(
250 ScaleDesignMatrixRef::Dense(primary_design),
251 ScaleDesignMatrixRef::Dense(noise_design),
252 weights,
253 non_intercept_start,
254 "scale deviation transform row mismatch",
255 )
256 .map_err(|e| e.to_string())
257}
258
259pub fn apply_scale_deviation_transform(
260 primary_design: &Array2<f64>,
261 rawnoise_design: &Array2<f64>,
262 transform: &ScaleDeviationTransform,
263) -> Result<Array2<f64>, String> {
264 apply_scale_deviation_transform_typed(primary_design, rawnoise_design, transform)
265 .map_err(|e| e.to_string())
266}
267
268fn apply_scale_deviation_transform_typed(
269 primary_design: &Array2<f64>,
270 rawnoise_design: &Array2<f64>,
271 transform: &ScaleDeviationTransform,
272) -> Result<Array2<f64>, ScaleDesignError> {
273 if primary_design.nrows() != rawnoise_design.nrows() {
274 return Err(dim_err("scale deviation apply row mismatch"));
275 }
276 if primary_design.ncols() != transform.projection_coef.nrows()
277 || rawnoise_design.ncols() != transform.projection_coef.ncols()
278 {
279 return Err(dim_err("scale deviation apply column mismatch"));
280 }
281 let n = rawnoise_design.nrows();
282 let p_primary = primary_design.ncols();
283 let p_noise = rawnoise_design.ncols();
284 let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
285 let mut out = Array2::<f64>::zeros((n, p_noise));
286 for start in (0..n).step_by(chunk_rows) {
287 let end = (start + chunk_rows).min(n);
288 let primary_chunk = primary_design.slice(s![start..end, ..]).to_owned();
289 let noise_chunk = rawnoise_design.slice(s![start..end, ..]).to_owned();
290 let chunk = apply_scale_deviation_reparam_chunk(&primary_chunk, &noise_chunk, transform);
291 out.slice_mut(s![start..end, ..]).assign(&chunk);
292 }
293 Ok(out)
294}
295
296#[derive(Clone)]
297struct ScaleDeviationOperator {
298 primary_design: DesignMatrix,
299 rawnoise_design: DesignMatrix,
300 transform: ScaleDeviationTransform,
301 chunk_rows: usize,
302}
303
304impl ScaleDeviationOperator {
305 fn row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, ScaleDesignError> {
306 let primary_chunk = self
307 .primary_design
308 .try_row_chunk(rows.clone())
309 .map_err(|e| ScaleDesignError::RowMaterializationFailed {
310 reason: format!("scale deviation operator primary chunk: {e}"),
311 })?;
312 let noise_chunk = self.rawnoise_design.try_row_chunk(rows).map_err(|e| {
313 ScaleDesignError::RowMaterializationFailed {
314 reason: format!("scale deviation operator noise chunk: {e}"),
315 }
316 })?;
317 Ok(apply_scale_deviation_reparam_chunk(
318 &primary_chunk,
319 &noise_chunk,
320 &self.transform,
321 ))
322 }
323}
324
325impl LinearOperator for ScaleDeviationOperator {
326 fn nrows(&self) -> usize {
327 self.rawnoise_design.nrows()
328 }
329
330 fn ncols(&self) -> usize {
331 self.rawnoise_design.ncols()
332 }
333
334 fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
335 assert_eq!(vector.len(), self.ncols());
336 let n = self.nrows();
337 let mut out = Array1::<f64>::zeros(n);
338 for start in (0..n).step_by(self.chunk_rows) {
339 let end = (start + self.chunk_rows).min(n);
340 let chunk = self
341 .row_chunk(start..end)
342 .expect("scale deviation operator row chunk failed");
343 out.slice_mut(s![start..end]).assign(&chunk.dot(vector));
344 }
345 out
346 }
347
348 fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
349 assert_eq!(vector.len(), self.nrows());
350 let n = self.nrows();
351 let p = self.ncols();
352 let mut out = Array1::<f64>::zeros(p);
353 for start in (0..n).step_by(self.chunk_rows) {
354 let end = (start + self.chunk_rows).min(n);
355 let chunk = self
356 .row_chunk(start..end)
357 .expect("scale deviation operator row chunk failed");
358 out += &chunk.t().dot(&vector.slice(s![start..end]).to_owned());
359 }
360 out
361 }
362
363 fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
364 if weights.len() != self.nrows() {
365 return Err(dim_err(format!(
366 "scale deviation operator XtWX weight mismatch: weights={}, rows={}",
367 weights.len(),
368 self.nrows()
369 ))
370 .to_string());
371 }
372 let n = self.nrows();
373 let p = self.ncols();
374 let mut out = Array2::<f64>::zeros((p, p));
375 for start in (0..n).step_by(self.chunk_rows) {
376 let end = (start + self.chunk_rows).min(n);
377 let chunk = self.row_chunk(start..end).map_err(|e| e.to_string())?;
378 for local in 0..chunk.nrows() {
379 let w = weights[start + local].max(0.0);
380 if w == 0.0 {
381 continue;
382 }
383 for a in 0..p {
384 let xa = chunk[[local, a]];
385 for b in a..p {
386 let value = w * xa * chunk[[local, b]];
387 out[[a, b]] += value;
388 if a != b {
389 out[[b, a]] += value;
390 }
391 }
392 }
393 }
394 }
395 Ok(out)
396 }
397
398 fn uses_matrix_free_pcg(&self) -> bool {
399 self.primary_design
400 .nrows()
401 .saturating_mul(self.rawnoise_design.ncols())
402 > SCALE_OPERATOR_MATRIX_FREE_PCG_THRESHOLD
403 }
404}
405
406impl DenseDesignOperator for ScaleDeviationOperator {
407 fn row_chunk_into(
408 &self,
409 rows: Range<usize>,
410 mut out: ArrayViewMut2<'_, f64>,
411 ) -> Result<(), gam_runtime::resource::MatrixMaterializationError> {
412 let chunk = self.row_chunk(rows).map_err(|err| {
413 gam_runtime::resource::MatrixMaterializationError::RowMaterializationFailed {
414 context: "ScaleDeviationOperator::row_chunk_into",
415 reason: err.to_string(),
416 }
417 })?;
418 out.assign(&chunk);
419 Ok(())
420 }
421
422 fn to_dense(&self) -> Array2<f64> {
423 let n = self.nrows();
424 let p = self.ncols();
425 let mut out = Array2::<f64>::zeros((n, p));
426 for start in (0..n).step_by(self.chunk_rows) {
427 let end = (start + self.chunk_rows).min(n);
428 let chunk = self
429 .row_chunk(start..end)
430 .expect("scale deviation operator row chunk failed");
431 out.slice_mut(s![start..end, ..]).assign(&chunk);
432 }
433 out
434 }
435}
436
437#[derive(Debug)]
438struct WeightedColumnStats {
439 weighted_sum: Array1<f64>,
440 weighted_sum_sq: Array1<f64>,
441 total_weight: f64,
442}
443
444fn validate_scale_weights(weights: &Array1<f64>) -> Result<f64, ScaleDesignError> {
445 let mut total_weight = 0.0;
446 for (idx, &w) in weights.iter().enumerate() {
447 if !w.is_finite() {
448 return Err(ScaleDesignError::NonFiniteInput {
449 reason: format!("scale deviation weight {idx} is not finite"),
450 });
451 }
452 if w < 0.0 {
453 return Err(ScaleDesignError::InvalidWeights {
454 reason: format!(
455 "scale deviation requires non-negative weights, got {w} at index {idx}"
456 ),
457 });
458 }
459 total_weight += w;
460 }
461 if !total_weight.is_finite() || total_weight <= 0.0 {
462 return Err(ScaleDesignError::InvalidWeights {
463 reason: "scale deviation requires positive finite total weight".to_string(),
464 });
465 }
466 Ok(total_weight)
467}
468
469fn scale_design_row_chunk_size(nrows: usize, max_cols: usize) -> usize {
470 (SCALE_DESIGN_TARGET_CHUNK_BYTES / (max_cols.max(1) * std::mem::size_of::<f64>()))
471 .max(1)
472 .min(nrows.max(1))
473}
474
475fn weighted_column_stats(
476 design: ScaleDesignMatrixRef<'_>,
477 weights: &Array1<f64>,
478 row_mismatch_error: String,
479) -> Result<WeightedColumnStats, ScaleDesignError> {
480 if design.nrows() != weights.len() {
481 return Err(dim_err(row_mismatch_error));
482 }
483 let total_weight = validate_scale_weights(weights)?;
484 let p = design.ncols();
485 let mut weighted_sum = Array1::<f64>::zeros(p);
486 let mut weighted_sum_sq = Array1::<f64>::zeros(p);
487 let chunk_rows = scale_design_row_chunk_size(design.nrows(), p);
488 for start in (0..design.nrows()).step_by(chunk_rows) {
489 let end = (start + chunk_rows).min(design.nrows());
490 let chunk = design.row_chunk(start..end)?;
491 for local in 0..(end - start) {
492 let w = weights[start + local];
493 if w == 0.0 {
494 continue;
495 }
496 for j in 0..p {
497 let x = chunk[[local, j]];
498 weighted_sum[j] += w * x;
499 weighted_sum_sq[j] += w * x * x;
500 }
501 }
502 }
503 Ok(WeightedColumnStats {
504 weighted_sum,
505 weighted_sum_sq,
506 total_weight,
507 })
508}
509
510fn infer_non_intercept_start_impl(
511 design: ScaleDesignMatrixRef<'_>,
512 weights: &Array1<f64>,
513 row_mismatch_error: String,
514) -> Result<usize, ScaleDesignError> {
515 let stats = weighted_column_stats(design, weights, row_mismatch_error)?;
516 let mut end = 0;
517 for j in 0..stats.weighted_sum.len() {
518 let centered_ss = stats.weighted_sum_sq[j]
519 - stats.weighted_sum[j] * stats.weighted_sum[j] / stats.total_weight;
520 if centered_ss <= COLUMN_TOL {
521 end = j + 1;
522 } else {
523 break;
524 }
525 }
526 Ok(end)
527}
528
529fn build_weighted_primary_design(
530 primary_design: ScaleDesignMatrixRef<'_>,
531 sqrtw: &Array1<f64>,
532 chunk_rows: usize,
533) -> Result<Array2<f64>, ScaleDesignError> {
534 let n = primary_design.nrows();
535 let p_primary = primary_design.ncols();
536 let mut wx = Array2::<f64>::zeros((n, p_primary));
537 for start in (0..n).step_by(chunk_rows) {
538 let end = (start + chunk_rows).min(n);
539 let x_chunk = primary_design.row_chunk(start..end)?;
540 for local in 0..(end - start) {
541 let sw = sqrtw[start + local];
542 for col in 0..p_primary {
543 wx[[start + local, col]] = sw * x_chunk[[local, col]];
544 }
545 }
546 }
547 Ok(wx)
548}
549
550fn choose_scale_projection_ridge_alpha(singular: &[f64]) -> f64 {
559 if singular.is_empty() {
560 return 0.0;
561 }
562 let sigma_max = singular.iter().copied().fold(0.0_f64, f64::max);
563 if !sigma_max.is_finite() || sigma_max <= 0.0 {
564 return 0.0;
565 }
566 let derived_tol = sigma_max / SCALE_PROJECTION_LEVERAGE_AMPLIFICATION;
567 let truncation_tol = derived_tol.max(SCALE_PROJECTION_REPLAY_RCOND_FLOOR * sigma_max);
568 truncation_tol * truncation_tol
569}
570
571fn solve_scale_projection(
572 primary_design: ScaleDesignMatrixRef<'_>,
573 noise_design: ScaleDesignMatrixRef<'_>,
574 weights: &Array1<f64>,
575 first_active: usize,
576 chunk_rows: usize,
577) -> Result<(Array2<f64>, f64), ScaleDesignError> {
578 let n = primary_design.nrows();
579 let p_primary = primary_design.ncols();
580 let p_noise = noise_design.ncols();
581 let mut projection_coef = Array2::<f64>::zeros((p_primary, p_noise));
582 let active_cols = p_noise.saturating_sub(first_active);
583
584 if active_cols == 0 || p_primary == 0 {
585 return Ok((projection_coef, 0.0));
586 }
587
588 let sqrtw = weights.mapv(f64::sqrt);
589 let wx = build_weighted_primary_design(primary_design, &sqrtw, chunk_rows)?;
590 let (u_opt, singular, vt_opt) =
594 wx.svd(true, true)
595 .map_err(|e| ScaleDesignError::SvdFailed {
596 reason: format!("scale projection SVD failed: {e:?}"),
597 })?;
598 let (Some(u), Some(vt)) = (u_opt, vt_opt) else {
599 return Err(ScaleDesignError::SvdFailed {
600 reason: "scale projection SVD did not return singular vectors".to_string(),
601 });
602 };
603 let alpha = choose_scale_projection_ridge_alpha(singular.as_slice().unwrap_or(&[]));
604 let rank = singular.len();
605 if rank == 0 {
606 return Ok((projection_coef, alpha));
607 }
608 let cutoff = alpha.sqrt();
618 let mut filter = Array1::<f64>::zeros(rank);
619 for k in 0..rank {
620 let s = singular[k];
621 filter[k] = if s > cutoff && s > 0.0 { 1.0 / s } else { 0.0 };
622 }
623
624 let chunk_cols = (SCALE_DESIGN_TARGET_CHUNK_BYTES / (n.max(1) * std::mem::size_of::<f64>()))
625 .max(1)
626 .min(active_cols);
627
628 for chunk_start in (0..active_cols).step_by(chunk_cols) {
629 let width = (active_cols - chunk_start).min(chunk_cols);
630 let mut rhs = Array2::<f64>::zeros((n, width));
631 for start in (0..n).step_by(chunk_rows) {
632 let end = (start + chunk_rows).min(n);
633 let noise_chunk = noise_design.row_chunk(start..end)?;
634 for local in 0..(end - start) {
635 let sw = sqrtw[start + local];
636 for col in 0..width {
637 rhs[[start + local, col]] =
638 sw * noise_chunk[[local, first_active + chunk_start + col]];
639 }
640 }
641 }
642
643 let mut t = u.t().dot(&rhs);
645 for k in 0..rank {
647 let f = filter[k];
648 for col in 0..width {
649 t[[k, col]] *= f;
650 }
651 }
652 let block = vt.t().dot(&t);
655 for col in 0..width {
656 for row in 0..p_primary {
657 projection_coef[[row, first_active + chunk_start + col]] = block[[row, col]];
658 }
659 }
660 }
661
662 Ok((projection_coef, alpha))
663}
664
665fn apply_projection_chunk(
666 primary_chunk: &Array2<f64>,
667 projection_coef: &Array2<f64>,
668 first_active: usize,
669) -> Array2<f64> {
670 if first_active >= projection_coef.ncols() {
671 Array2::<f64>::zeros((primary_chunk.nrows(), 0))
672 } else {
673 fast_ab(
674 primary_chunk,
675 &projection_coef.slice(s![.., first_active..]).to_owned(),
676 )
677 }
678}
679
680fn build_scale_deviation_transform_impl(
681 primary_design: ScaleDesignMatrixRef<'_>,
682 noise_design: ScaleDesignMatrixRef<'_>,
683 weights: &Array1<f64>,
684 non_intercept_start: usize,
685 row_mismatch_error: &str,
686) -> Result<ScaleDeviationTransform, ScaleDesignError> {
687 if primary_design.nrows() != noise_design.nrows() || weights.len() != noise_design.nrows() {
688 return Err(dim_err(row_mismatch_error.to_string()));
689 }
690 validate_scale_weights(weights)?;
691
692 let n = primary_design.nrows();
693 let p_primary = primary_design.ncols();
694 let p_noise = noise_design.ncols();
695 let first_active = non_intercept_start.min(p_noise);
696 let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
697 let (projection_coef, projection_ridge_alpha) = solve_scale_projection(
698 primary_design,
699 noise_design,
700 weights,
701 first_active,
702 chunk_rows,
703 )?;
704 let mut weighted_column_mean = Array1::<f64>::zeros(p_noise);
705 let mut rescale = Array1::<f64>::ones(p_noise);
706 let active_cols = p_noise - first_active;
707
708 if active_cols > 0 {
709 let projection_only_transform = ScaleDeviationTransform {
710 projection_coef: projection_coef.clone(),
711 weighted_column_mean: Array1::<f64>::zeros(p_noise),
712 rescale: Array1::<f64>::ones(p_noise),
713 non_intercept_start,
714 projection_ridge_alpha,
715 };
716 let mut w_sum = 0.0;
717 let mut w_resid_sum = Array1::<f64>::zeros(active_cols);
718 let mut w_noise_sum = Array1::<f64>::zeros(active_cols);
719
720 for start in (0..n).step_by(chunk_rows) {
721 let end = (start + chunk_rows).min(n);
722 let x_chunk = primary_design.row_chunk(start..end)?;
723 let noise_chunk = noise_design.row_chunk(start..end)?;
724 let resid_chunk = apply_scale_deviation_reparam_chunk(
725 &x_chunk,
726 &noise_chunk,
727 &projection_only_transform,
728 );
729 for local in 0..(end - start) {
730 let w = weights[start + local];
731 if w == 0.0 {
732 continue;
733 }
734 w_sum += w;
735 for jj in 0..active_cols {
736 let nij = noise_chunk[[local, first_active + jj]];
737 w_noise_sum[jj] += w * nij;
738 w_resid_sum[jj] += w * resid_chunk[[local, first_active + jj]];
739 }
740 }
741 }
742
743 if !w_sum.is_finite() || w_sum <= 0.0 {
744 return Err(ScaleDesignError::InvalidWeights {
745 reason: "scale deviation requires positive finite total weight".to_string(),
746 });
747 }
748
749 let resid_center = w_resid_sum.mapv(|sum| sum / w_sum);
750 let noise_mean = w_noise_sum.mapv(|sum| sum / w_sum);
751 let mut orig_css = Array1::<f64>::zeros(active_cols);
752 let mut resid_css = Array1::<f64>::zeros(active_cols);
753
754 for start in (0..n).step_by(chunk_rows) {
755 let end = (start + chunk_rows).min(n);
756 let x_chunk = primary_design.row_chunk(start..end)?;
757 let noise_chunk = noise_design.row_chunk(start..end)?;
758 let resid_chunk = apply_scale_deviation_reparam_chunk(
759 &x_chunk,
760 &noise_chunk,
761 &projection_only_transform,
762 );
763 for local in 0..(end - start) {
764 let w = weights[start + local];
765 if w == 0.0 {
766 continue;
767 }
768 for jj in 0..active_cols {
769 let nij = noise_chunk[[local, first_active + jj]];
770 let d_orig = nij - noise_mean[jj];
771 orig_css[jj] += w * d_orig * d_orig;
772 let d_resid = resid_chunk[[local, first_active + jj]] - resid_center[jj];
773 resid_css[jj] += w * d_resid * d_resid;
774 }
775 }
776 }
777
778 for jj in 0..active_cols {
779 let j = first_active + jj;
780 let scale = if resid_css[jj].is_finite()
781 && resid_css[jj] > COLUMN_TOL
782 && orig_css[jj].is_finite()
783 && orig_css[jj] > COLUMN_TOL
784 {
785 (orig_css[jj] / resid_css[jj]).sqrt()
786 } else {
787 1.0
788 };
789 weighted_column_mean[j] = resid_center[jj];
790 rescale[j] = scale;
791 }
792 }
793
794 Ok(ScaleDeviationTransform {
795 projection_coef,
796 weighted_column_mean,
797 rescale,
798 non_intercept_start,
799 projection_ridge_alpha,
800 })
801}
802
803pub fn infer_non_intercept_start_design(
804 design: &DesignMatrix,
805 weights: &Array1<f64>,
806) -> Result<usize, String> {
807 infer_non_intercept_start_impl(
808 ScaleDesignMatrixRef::Design(design),
809 weights,
810 format!(
811 "weighted column stats row mismatch: design has {} rows, weights have {} entries",
812 design.nrows(),
813 weights.len()
814 ),
815 )
816 .map_err(|e| e.to_string())
817}
818
819pub fn build_scale_deviation_transform_design(
820 primary_design: &DesignMatrix,
821 noise_design: &DesignMatrix,
822 weights: &Array1<f64>,
823 non_intercept_start: usize,
824) -> Result<ScaleDeviationTransform, String> {
825 build_scale_deviation_transform_impl(
826 ScaleDesignMatrixRef::Design(primary_design),
827 ScaleDesignMatrixRef::Design(noise_design),
828 weights,
829 non_intercept_start,
830 "scale deviation transform design row mismatch",
831 )
832 .map_err(|e| e.to_string())
833}
834
835fn apply_scale_deviation_reparam_chunk(
843 primary_chunk: &Array2<f64>,
844 noise_chunk: &Array2<f64>,
845 transform: &ScaleDeviationTransform,
846) -> Array2<f64> {
847 let rows = noise_chunk.nrows();
848 let p_noise = noise_chunk.ncols();
849 let first_active = transform.non_intercept_start.min(p_noise);
850 let mut out = Array2::<f64>::zeros((rows, p_noise));
851
852 for j in 0..first_active {
854 for i in 0..rows {
855 out[[i, j]] = noise_chunk[[i, j]];
856 }
857 }
858
859 if first_active < p_noise {
861 let fitted =
862 apply_projection_chunk(primary_chunk, &transform.projection_coef, first_active);
863 for j in first_active..p_noise {
864 let jj = j - first_active;
865 let scale = transform.rescale[j];
866 let center = transform.weighted_column_mean[j];
867 for i in 0..rows {
868 out[[i, j]] = (noise_chunk[[i, j]] - fitted[[i, jj]] - center) * scale;
869 }
870 }
871 }
872
873 out
874}
875
876pub fn build_scale_deviation_operator(
877 primary_design: DesignMatrix,
878 rawnoise_design: DesignMatrix,
879 transform: &ScaleDeviationTransform,
880) -> Result<DesignMatrix, String> {
881 build_scale_deviation_operator_typed(primary_design, rawnoise_design, transform)
882 .map_err(|e| e.to_string())
883}
884
885fn build_scale_deviation_operator_typed(
886 primary_design: DesignMatrix,
887 rawnoise_design: DesignMatrix,
888 transform: &ScaleDeviationTransform,
889) -> Result<DesignMatrix, ScaleDesignError> {
890 if primary_design.nrows() != rawnoise_design.nrows() {
891 return Err(dim_err(format!(
892 "scale deviation operator row mismatch: primary rows={}, noise rows={}",
893 primary_design.nrows(),
894 rawnoise_design.nrows()
895 )));
896 }
897 if primary_design.ncols() != transform.projection_coef.nrows()
898 || rawnoise_design.ncols() != transform.projection_coef.ncols()
899 {
900 return Err(dim_err(format!(
901 "scale deviation operator column mismatch: primary cols={}, noise cols={}, transform is {}x{}",
902 primary_design.ncols(),
903 rawnoise_design.ncols(),
904 transform.projection_coef.nrows(),
905 transform.projection_coef.ncols()
906 )));
907 }
908 let n = rawnoise_design.nrows();
909 let p_primary = primary_design.ncols();
910 let p_noise = rawnoise_design.ncols();
911 let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
912 Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
913 ScaleDeviationOperator {
914 primary_design,
915 rawnoise_design,
916 transform: transform.clone(),
917 chunk_rows,
918 },
919 ))))
920}
921
922#[cfg(test)]
923mod tests {
924 use super::*;
925 use gam_linalg::matrix::DesignMatrix;
926
927 fn assert_matrix_close(lhs: &Array2<f64>, rhs: &Array2<f64>, tol: f64, label: &str) {
928 assert_eq!(
929 lhs.dim(),
930 rhs.dim(),
931 "{label} shape mismatch: left {:?}, right {:?}",
932 lhs.dim(),
933 rhs.dim()
934 );
935 for i in 0..lhs.nrows() {
936 for j in 0..lhs.ncols() {
937 assert!(
938 (lhs[[i, j]] - rhs[[i, j]]).abs() <= tol,
939 "{label} mismatch at ({i}, {j}): {} vs {}",
940 lhs[[i, j]],
941 rhs[[i, j]]
942 );
943 }
944 }
945 }
946
947 fn assert_transform_close(
948 lhs: &ScaleDeviationTransform,
949 rhs: &ScaleDeviationTransform,
950 tol: f64,
951 ) {
952 assert_eq!(lhs.non_intercept_start, rhs.non_intercept_start);
953 assert_matrix_close(
954 &lhs.projection_coef,
955 &rhs.projection_coef,
956 tol,
957 "projection coefficients",
958 );
959 assert_eq!(
960 lhs.weighted_column_mean.len(),
961 rhs.weighted_column_mean.len()
962 );
963 assert_eq!(lhs.rescale.len(), rhs.rescale.len());
964 for j in 0..lhs.weighted_column_mean.len() {
965 assert!(
966 (lhs.weighted_column_mean[j] - rhs.weighted_column_mean[j]).abs() <= tol,
967 "weighted column mean mismatch at {j}: {} vs {}",
968 lhs.weighted_column_mean[j],
969 rhs.weighted_column_mean[j]
970 );
971 assert!(
972 (lhs.rescale[j] - rhs.rescale[j]).abs() <= tol,
973 "rescale mismatch at {j}: {} vs {}",
974 lhs.rescale[j],
975 rhs.rescale[j]
976 );
977 }
978 }
979
980 #[test]
981 fn scale_deviation_transform_overdetermined() {
982 let n = 1000;
983 let p_primary = 10;
984 let p_noise = 5;
985
986 let mut primary = Array2::<f64>::zeros((n, p_primary));
987 let mut noise = Array2::<f64>::zeros((n, p_noise));
988 for i in 0..n {
989 for j in 0..p_primary {
990 primary[[i, j]] = ((i * 3 + j * 11) as f64 * 0.1).sin();
991 }
992 for j in 0..p_noise {
993 noise[[i, j]] = ((i * 5 + j * 13) as f64 * 0.1).cos();
994 }
995 }
996 noise.column_mut(0).fill(1.0);
997 let weights = Array1::<f64>::ones(n);
998
999 let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1000 .expect("transform should succeed for overdetermined inputs");
1001 let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
1002 .expect("apply should succeed for overdetermined inputs");
1003
1004 assert_eq!(transform.projection_coef.dim(), (p_primary, p_noise));
1005 assert_eq!(transformed.dim(), (n, p_noise));
1006 assert!(transformed.iter().all(|v| v.is_finite()));
1007 assert!(transformed.column(0).iter().all(|&v| v == 1.0));
1008
1009 let primary_design =
1010 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(primary.clone()));
1011 let noise_design =
1012 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(noise.clone()));
1013 let non_intercept_start = infer_non_intercept_start_design(&noise_design, &weights)
1014 .expect("design-native non-intercept detection should succeed");
1015 assert_eq!(non_intercept_start, 1);
1016 let design_transform = build_scale_deviation_transform_design(
1017 &primary_design,
1018 &noise_design,
1019 &weights,
1020 non_intercept_start,
1021 )
1022 .expect("design-native transform should succeed");
1023 let transformed_design =
1024 build_scale_deviation_operator(primary_design, noise_design, &design_transform)
1025 .expect("design-native operator should build")
1026 .to_dense();
1027
1028 assert_eq!(design_transform.projection_coef.dim(), (p_primary, p_noise));
1029 assert_eq!(transformed_design.dim(), transformed.dim());
1030 assert_transform_close(&transform, &design_transform, 1e-10);
1031 assert_matrix_close(
1032 &transformed_design,
1033 &transformed,
1034 1e-8,
1035 "transformed design",
1036 );
1037 }
1038
1039 #[test]
1040 fn scale_deviation_transform_rank_deficient_primary_matches_design_path() {
1041 let n = 384;
1042 let p_primary = 4;
1043 let p_noise = 4;
1044 let mut primary = Array2::<f64>::zeros((n, p_primary));
1045 let mut noise = Array2::<f64>::zeros((n, p_noise));
1046 let mut weights = Array1::<f64>::zeros(n);
1047
1048 for i in 0..n {
1049 let t = i as f64 / n as f64;
1050 let wobble = (17.0 * t).sin();
1051 primary[[i, 0]] = 1.0;
1052 primary[[i, 1]] = t;
1053 primary[[i, 2]] = t + 1e-12 * wobble;
1054 primary[[i, 3]] = 2.0 * t - 1e-12 * wobble;
1055
1056 noise[[i, 0]] = 1.0;
1057 noise[[i, 1]] = 0.7 * t + 0.2 * (9.0 * t).cos();
1058 noise[[i, 2]] = primary[[i, 1]] - primary[[i, 2]] + 0.1 * (13.0 * t).sin();
1059 noise[[i, 3]] = 0.5 * primary[[i, 3]] + 0.3 * (5.0 * t).cos();
1060
1061 weights[i] = if i % 17 == 0 {
1062 0.0
1063 } else {
1064 0.5 + (11.0 * t).sin().abs()
1065 };
1066 }
1067
1068 let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1069 .expect("dense transform should succeed for ill-conditioned primary");
1070 let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
1071 .expect("dense apply should succeed for ill-conditioned primary");
1072
1073 let primary_design =
1074 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(primary.clone()));
1075 let noise_design =
1076 DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(noise.clone()));
1077 let non_intercept_start = infer_non_intercept_start_design(&noise_design, &weights)
1078 .expect("design-native non-intercept detection should succeed");
1079 assert_eq!(non_intercept_start, 1);
1080
1081 let design_transform = build_scale_deviation_transform_design(
1082 &primary_design,
1083 &noise_design,
1084 &weights,
1085 non_intercept_start,
1086 )
1087 .expect("design-native transform should succeed for ill-conditioned primary");
1088 let transformed_design =
1089 build_scale_deviation_operator(primary_design, noise_design, &design_transform)
1090 .expect("design-native operator should build for ill-conditioned primary")
1091 .to_dense();
1092
1093 assert_transform_close(&transform, &design_transform, 1e-10);
1094 assert_matrix_close(
1095 &transformed_design,
1096 &transformed,
1097 1e-8,
1098 "ill-conditioned transformed design",
1099 );
1100 }
1101
/// Pins the values returned by `choose_scale_projection_ridge_alpha`:
/// `(SCALE_PROJECTION_REPLAY_RCOND_FLOOR * sigma_max)^2` for the two non-empty
/// spectra used here, and exactly `0.0` for an empty spectrum.
#[test]
fn choose_scale_projection_ridge_alpha_scales_with_sigma_max() {
    // Spectrum whose largest entry is 1.0: alpha must equal the squared floor.
    let alpha_unit = choose_scale_projection_ridge_alpha(&[1.0, 0.5, 1e-6]);
    let expected_unit = SCALE_PROJECTION_REPLAY_RCOND_FLOOR.powi(2);
    assert!(alpha_unit > 0.0);
    let unit_error = (alpha_unit - expected_unit).abs();
    assert!(
        unit_error < 1e-24,
        "alpha should be {expected_unit:e} for sigma_max=1, got {alpha_unit}"
    );

    // Spectrum whose largest entry is 100.0: alpha must grow with the square of it.
    let alpha_scaled = choose_scale_projection_ridge_alpha(&[100.0, 1.0]);
    let expected_scaled = (SCALE_PROJECTION_REPLAY_RCOND_FLOOR * 100.0).powi(2);
    let scaled_error = (alpha_scaled - expected_scaled).abs();
    assert!(
        scaled_error < 1e-18,
        "alpha should be {expected_scaled:e} for sigma_max=100, got {alpha_scaled}"
    );

    // A 100x larger top singular value must yield a 1e4x larger alpha.
    let ratio = alpha_scaled / alpha_unit;
    assert!(
        (ratio - 1.0e4).abs() < 1e-6,
        "alpha should scale as sigma_max^2; got ratio {}",
        ratio
    );

    // An empty spectrum yields no ridge at all.
    let alpha_empty = choose_scale_projection_ridge_alpha(&[]);
    assert_eq!(alpha_empty, 0.0);
}
1131
/// Sweeps a constant offset in `[0, 1]` through the second noise column and
/// checks that the `[2, 1]` projection coefficient never moves by 0.5 or more
/// between two consecutive sweep positions.
#[test]
fn ridge_replay_continuous_under_input_sweep() {
    let n = 64_usize;

    // Primary design: constant column, `t`, and `t` plus a 1e-9-sized wiggle,
    // so the last two columns are almost identical.
    let primary = Array2::<f64>::from_shape_fn((n, 3), |(row, col)| {
        let t = row as f64 / n as f64;
        match col {
            0 => 1.0,
            1 => t,
            _ => t + 1e-9 * (5.0 * t).sin(),
        }
    });
    // Noise design: constant column plus one smooth column.
    let noise = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| {
        let t = row as f64 / n as f64;
        if col == 0 { 1.0 } else { (0.4 * t).cos() }
    });
    let weights = Array1::<f64>::ones(n);

    // One coefficient per sweep position, in sweep order.
    let coefficients: Vec<f64> = (0..50)
        .map(|k| {
            let s = k as f64 / 49.0;
            let mut perturbed = noise.clone();
            perturbed.column_mut(1).mapv_inplace(|v| v + s);
            let transform = build_scale_deviation_transform(&primary, &perturbed, &weights, 1)
                .expect("ridge transform should succeed under input sweep");
            transform.projection_coef[[2, 1]]
        })
        .collect();

    // Largest absolute change between neighbouring sweep positions.
    let max_step = coefficients
        .windows(2)
        .fold(0.0_f64, |largest, pair| {
            largest.max((pair[1] - pair[0]).abs())
        });
    assert!(
        max_step < 0.5,
        "replay coefficient sweep should be continuous, got max step {max_step}"
    );
}
1180
/// Builds noise columns that are exact linear combinations of primary columns
/// and checks that, after the transform is applied, the first noise column is
/// still exactly `1.0`, every other column is below `1e-6` in magnitude, and
/// the stored ridge alpha is strictly positive.
#[test]
fn ridge_replay_noise_free_is_near_identity() {
    let n = 128_usize;
    let p_primary = 4;
    let p_noise = 3;

    let primary = Array2::<f64>::from_shape_fn((n, p_primary), |(row, col)| {
        let t = row as f64 / n as f64;
        match col {
            0 => 1.0,
            1 => t,
            2 => (3.0 * t).sin(),
            _ => (2.0 * t - 0.4).powi(2),
        }
    });
    // Columns 1 and 2 are linear combinations of primary columns, with no
    // independent component added.
    let noise = Array2::<f64>::from_shape_fn((n, p_noise), |(row, col)| match col {
        0 => 1.0,
        1 => 0.7 * primary[[row, 1]] - 0.3 * primary[[row, 2]],
        _ => 0.2 * primary[[row, 3]] + 0.1 * primary[[row, 1]],
    });
    let weights = Array1::<f64>::ones(n);

    let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
        .expect("transform should succeed");
    let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
        .expect("apply should succeed");

    // The leading column must come back as exactly 1.0 in every row.
    (0..n).for_each(|i| assert_eq!(transformed[[i, 0]], 1.0));

    // Remaining columns, visited column by column and then row by row.
    for (i, j) in (1..p_noise).flat_map(|j| (0..n).map(move |i| (i, j))) {
        let residual = transformed[[i, j]];
        assert!(
            residual.abs() < 1e-6,
            "noise-free residual should be near zero at ({i},{j}), got {}",
            residual
        );
    }

    assert!(transform.projection_ridge_alpha > 0.0);
}
1229
/// Flattens a freshly built transform into the payload shape accepted by
/// `scale_transform_from_payload` and checks that `projection_ridge_alpha`
/// comes back bit-for-bit equal.
#[test]
fn scale_transform_payload_round_trips_alpha() {
    let n = 64_usize;
    let primary = Array2::<f64>::from_shape_fn((n, 3), |(row, col)| {
        let t = row as f64 / n as f64;
        match col {
            0 => 1.0,
            1 => t,
            _ => (4.0 * t).cos(),
        }
    });
    let noise = Array2::<f64>::from_shape_fn((n, 2), |(row, col)| {
        let t = row as f64 / n as f64;
        if col == 0 { 1.0 } else { (2.0 * t).sin() }
    });
    let weights = Array1::<f64>::ones(n);

    let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
        .expect("transform should succeed");

    // Payload form of the projection: one `Vec<f64>` per matrix row.
    let projection: Vec<Vec<f64>> = transform
        .projection_coef
        .outer_iter()
        .map(|row| row.to_vec())
        .collect();

    let restored = scale_transform_from_payload(
        &Some(projection),
        &Some(transform.weighted_column_mean.to_vec()),
        &Some(transform.rescale.to_vec()),
        Some(transform.non_intercept_start),
        Some(transform.projection_ridge_alpha),
    )
    .expect("payload round-trip should succeed")
    .expect("payload should produce a transform");

    assert_eq!(
        restored.projection_ridge_alpha, transform.projection_ridge_alpha,
        "alpha must round-trip exactly through payload serialization"
    );
}
1269}