Skip to main content

gam_model_kernels/
scale_design.rs

1use gam_linalg::faer_ndarray::{FaerSvd, fast_ab};
2use gam_linalg::matrix::{DenseDesignMatrix, DenseDesignOperator, DesignMatrix, LinearOperator};
3use ndarray::{Array1, Array2, ArrayViewMut2, s};
4use std::ops::Range;
5use std::sync::Arc;
6
/// Typed error variants for the scale-deviation design module.
///
/// External-facing helpers continue to return `Result<_, String>`; this enum
/// is materialized internally and converted at the boundary (via
/// `to_string()`) so that error text remains byte-identical to the previous
/// `format!` output.
#[derive(Debug, Clone)]
pub enum ScaleDesignError {
    /// Weight vector is invalid: an entry is negative, or the weights sum to a
    /// non-positive / non-finite total. Individual non-finite (NaN/inf)
    /// weights are reported as [`ScaleDesignError::NonFiniteInput`] instead.
    InvalidWeights { reason: String },
    /// Dimensions of the supplied matrices/vectors are inconsistent.
    IncompatibleDimensions { reason: String },
    /// Input value is not finite where finiteness is required (e.g. an
    /// individual weight, or the saved projection cutoff alpha — which is
    /// also reported through this variant when it is negative).
    NonFiniteInput { reason: String },
    /// Saved payload is partially populated or missing a required field
    /// (e.g. `projection_ridge_alpha`), or the projection is degenerate
    /// (e.g. zero rows with non-empty columns).
    DegenerateDesign { reason: String },
    /// Row materialization from an underlying `DesignMatrix` failed.
    RowMaterializationFailed { reason: String },
    /// Thin SVD of the weighted primary design failed or produced no
    /// singular vectors.
    SvdFailed { reason: String },
}

// Shared trait impls for the `{ reason }` variants, generated by a macro
// defined elsewhere in the crate. NOTE(review): presumably `Display` prints
// `reason` verbatim (the boundary conversions rely on `to_string()`) —
// confirm against the macro definition.
impl_reason_error_boilerplate! {
    ScaleDesignError {
        InvalidWeights,
        IncompatibleDimensions,
        NonFiniteInput,
        DegenerateDesign,
        RowMaterializationFailed,
        SvdFailed,
    }
}
42
// Absolute tolerance on a weighted centered sum of squares. Columns at or
// below it count as constant (intercept-like) when inferring
// `non_intercept_start`. Raw or residualized noise columns at or below it keep
// a unit rescale when the transform is fitted.
const COLUMN_TOL: f64 = 1e-12;
// Target size in bytes of each materialized row chunk (and of each RHS column
// block in the projection solve).
const SCALE_DESIGN_TARGET_CHUNK_BYTES: usize = 8 * 1024 * 1024;
// Numerical conditioning floor for the SVD truncation tolerance: we drop any
// singular direction at or below `RCOND_FLOOR * sigma_max`, which is the
// standard machine-precision boundary for considering a direction
// resolvable. Above this floor, the replay solve is unbiased least squares
// (no Tikhonov damping): retained directions are inverted exactly, so the
// part of each noise column in the primary span is recovered exactly. This
// is the primary safety net.
const SCALE_PROJECTION_REPLAY_RCOND_FLOOR: f64 = 1e-8;
// Cap on coefficient amplification. A retained direction must satisfy
// `sigma > sigma_max / SCALE_PROJECTION_LEVERAGE_AMPLIFICATION`, so its
// inverse amplifies a unit prediction row by at most this multiple of what a
// sigma_max-scale direction sees. With the current values (`1.0e8` here,
// `1e-8` for the rcond floor) this tolerance is exactly the rcond floor, so
// the cap never tightens the cutoff on its own; it only matters if this
// constant is lowered. Setting it much smaller than `1 / RCOND_FLOOR` would
// discard real signal from moderately conditioned bases and is intentionally
// avoided.
const SCALE_PROJECTION_LEVERAGE_AMPLIFICATION: f64 = 1.0e8;
// Above this many materialized entries (primary rows × noise columns) the
// scale-deviation operator routes its normal-equation solve through
// matrix-free PCG instead of forming a dense `XᵀWX`. The dense path costs
// `O(n · p²)` time and `O(p²)` memory. Once the explicit operator footprint
// reaches ~10⁶ doubles (~8 MB, roughly `SCALE_DESIGN_TARGET_CHUNK_BYTES`), the
// chunked matrix-free path is the cheaper, more cache-friendly route.
const SCALE_OPERATOR_MATRIX_FREE_PCG_THRESHOLD: usize = 1_000_000;
68
/// Fitted reparameterization that turns a raw scale (noise) design into its
/// scale-deviation form relative to a primary design.
///
/// Produced by `build_scale_deviation_transform*` (or
/// [`ScaleDeviationTransform::identity`]) and replayed chunk-wise on new rows.
#[derive(Clone, Debug)]
pub struct ScaleDeviationTransform {
    /// `p_primary x p_noise` projection coefficients of the noise columns onto
    /// the primary design. Columns before the (clamped) `non_intercept_start`
    /// are left at zero by the fit.
    pub projection_coef: Array2<f64>,
    /// Per-noise-column weighted mean of the projection residual, as recorded
    /// by the fit (zero for the leading, non-residualized columns).
    pub weighted_column_mean: Array1<f64>,
    /// Per-noise-column rescale factor (1.0 for leading columns and for
    /// columns whose spread is degenerate).
    pub rescale: Array1<f64>,
    /// Index of the first noise column that is residualized; the fit clamps it
    /// to the number of noise columns.
    pub non_intercept_start: usize,
    /// Squared SVD truncation cutoff used when fitting `projection_coef`:
    /// singular values at or below `sqrt(projection_ridge_alpha)` were dropped
    /// (despite the name, no ridge damping is applied to retained directions).
    /// Stored so prediction-time replay is reproducible without re-deriving
    /// the cutoff from heuristics; `0.0` means no truncation.
    pub projection_ridge_alpha: f64,
}
80
81impl ScaleDeviationTransform {
82    /// Identity (no-op) reparameterization: zero projection, zero centering,
83    /// unit rescale. [`build_scale_deviation_operator`] with this transform
84    /// returns the raw scale design verbatim, and the saved-payload round-trip
85    /// replays the same identity at prediction time.
86    ///
87    /// A location and a scale predictor remain SEPARATELY identifiable even when
88    /// they share a covariate basis: they enter the likelihood through different
89    /// sufficient statistics (the standardized residual versus its square / the
90    /// log-scale), so residualizing the scale design against the location design
91    /// — replacing `X_σ` with `(I − P_{X_μ}) X_σ` — imposes a spurious
92    /// constraint and erases real heteroscedastic signal whenever the two blocks
93    /// overlap. The Gaussian location-scale path already keeps its log-σ design
94    /// un-residualized (`identified_gaussian_log_sigma_design`); this constructor
95    /// lets the survival location-scale path do the same while preserving the
96    /// transform plumbing (payload serialization, prediction-time replay).
97    pub fn identity(p_primary: usize, p_noise: usize, non_intercept_start: usize) -> Self {
98        ScaleDeviationTransform {
99            projection_coef: Array2::<f64>::zeros((p_primary, p_noise)),
100            weighted_column_mean: Array1::<f64>::zeros(p_noise),
101            rescale: Array1::<f64>::ones(p_noise),
102            non_intercept_start,
103            projection_ridge_alpha: 0.0,
104        }
105    }
106}
107
108/// Build a [`ScaleDeviationTransform`] from saved projection metadata.
109///
110/// Returns `Ok(None)` only when the payload is completely absent; partial
111/// payloads are invalid because prediction cannot replay the fitted scale
112/// reparameterization unambiguously.
113pub fn scale_transform_from_payload(
114    projection: &Option<Vec<Vec<f64>>>,
115    center: &Option<Vec<f64>>,
116    scale: &Option<Vec<f64>>,
117    non_intercept_start: Option<usize>,
118    projection_ridge_alpha: Option<f64>,
119) -> Result<Option<ScaleDeviationTransform>, String> {
120    scale_transform_from_payload_typed(
121        projection,
122        center,
123        scale,
124        non_intercept_start,
125        projection_ridge_alpha,
126    )
127    .map_err(|e| e.to_string())
128}
129
130fn scale_transform_from_payload_typed(
131    projection: &Option<Vec<Vec<f64>>>,
132    center: &Option<Vec<f64>>,
133    scale: &Option<Vec<f64>>,
134    non_intercept_start: Option<usize>,
135    projection_ridge_alpha: Option<f64>,
136) -> Result<Option<ScaleDeviationTransform>, ScaleDesignError> {
137    match (projection, center, scale, non_intercept_start) {
138        (None, None, None, None) => Ok(None),
139        (Some(projection), Some(center), Some(scale), Some(non_intercept_start)) => {
140            let rows = projection.len();
141            let cols = center.len();
142            if cols != scale.len() {
143                return Err(ScaleDesignError::IncompatibleDimensions {
144                    reason: "saved scale transform center/scale length mismatch".to_string(),
145                });
146            }
147            if rows == 0 && cols > 0 {
148                return Err(ScaleDesignError::DegenerateDesign {
149                    reason: "saved scale transform projection has zero rows".to_string(),
150                });
151            }
152            let mut projection_coef = Array2::<f64>::zeros((rows, cols));
153            for (i, row) in projection.iter().enumerate() {
154                if row.len() != cols {
155                    return Err(ScaleDesignError::IncompatibleDimensions {
156                        reason: "saved scale transform projection width mismatch".to_string(),
157                    });
158                }
159                for (j, &value) in row.iter().enumerate() {
160                    projection_coef[[i, j]] = value;
161                }
162            }
163            let Some(projection_ridge_alpha) = projection_ridge_alpha else {
164                return Err(ScaleDesignError::DegenerateDesign {
165                    reason:
166                        "saved scale transform payload is missing projection_ridge_alpha; refit"
167                            .to_string(),
168                });
169            };
170            if !projection_ridge_alpha.is_finite() || projection_ridge_alpha < 0.0 {
171                return Err(ScaleDesignError::NonFiniteInput {
172                    reason: format!(
173                        "saved scale transform projection_ridge_alpha must be finite and non-negative, got {projection_ridge_alpha}"
174                    ),
175                });
176            }
177            Ok(Some(ScaleDeviationTransform {
178                projection_coef,
179                weighted_column_mean: Array1::from_vec(center.clone()),
180                rescale: Array1::from_vec(scale.clone()),
181                non_intercept_start,
182                projection_ridge_alpha,
183            }))
184        }
185        _ => Err(ScaleDesignError::DegenerateDesign {
186            reason: "saved scale transform payload is only partially populated; refit".to_string(),
187        }),
188    }
189}
190
/// Borrowed view over either a dense `Array2` or a `DesignMatrix`, letting the
/// fitting helpers share one code path that reads rows in chunks.
#[derive(Clone, Copy)]
enum ScaleDesignMatrixRef<'a> {
    /// Plain dense matrix; row chunks are copied out of a row slice.
    Dense(&'a Array2<f64>),
    /// Design matrix; row chunks come from `try_row_chunk`.
    Design(&'a DesignMatrix),
}
196
197impl ScaleDesignMatrixRef<'_> {
198    #[inline]
199    fn nrows(self) -> usize {
200        match self {
201            Self::Dense(matrix) => matrix.nrows(),
202            Self::Design(matrix) => matrix.nrows(),
203        }
204    }
205
206    #[inline]
207    fn ncols(self) -> usize {
208        match self {
209            Self::Dense(matrix) => matrix.ncols(),
210            Self::Design(matrix) => matrix.ncols(),
211        }
212    }
213
214    fn row_chunk(self, rows: Range<usize>) -> Result<Array2<f64>, ScaleDesignError> {
215        match self {
216            Self::Dense(matrix) => Ok(matrix.slice(s![rows, ..]).to_owned()),
217            Self::Design(matrix) => {
218                matrix
219                    .try_row_chunk(rows)
220                    .map_err(|e| ScaleDesignError::RowMaterializationFailed {
221                        reason: format!("scale deviation row materialization failed: {e}"),
222                    })
223            }
224        }
225    }
226}
227
228pub fn infer_non_intercept_start(design: &Array2<f64>, weights: &Array1<f64>) -> usize {
229    infer_non_intercept_start_impl(
230        ScaleDesignMatrixRef::Dense(design),
231        weights,
232        "weighted column stats row mismatch".to_string(),
233    )
234    .unwrap_or(0)
235}
236
237fn dim_err(reason: impl Into<String>) -> ScaleDesignError {
238    ScaleDesignError::IncompatibleDimensions {
239        reason: reason.into(),
240    }
241}
242
243pub fn build_scale_deviation_transform(
244    primary_design: &Array2<f64>,
245    noise_design: &Array2<f64>,
246    weights: &Array1<f64>,
247    non_intercept_start: usize,
248) -> Result<ScaleDeviationTransform, String> {
249    build_scale_deviation_transform_impl(
250        ScaleDesignMatrixRef::Dense(primary_design),
251        ScaleDesignMatrixRef::Dense(noise_design),
252        weights,
253        non_intercept_start,
254        "scale deviation transform row mismatch",
255    )
256    .map_err(|e| e.to_string())
257}
258
259pub fn apply_scale_deviation_transform(
260    primary_design: &Array2<f64>,
261    rawnoise_design: &Array2<f64>,
262    transform: &ScaleDeviationTransform,
263) -> Result<Array2<f64>, String> {
264    apply_scale_deviation_transform_typed(primary_design, rawnoise_design, transform)
265        .map_err(|e| e.to_string())
266}
267
268fn apply_scale_deviation_transform_typed(
269    primary_design: &Array2<f64>,
270    rawnoise_design: &Array2<f64>,
271    transform: &ScaleDeviationTransform,
272) -> Result<Array2<f64>, ScaleDesignError> {
273    if primary_design.nrows() != rawnoise_design.nrows() {
274        return Err(dim_err("scale deviation apply row mismatch"));
275    }
276    if primary_design.ncols() != transform.projection_coef.nrows()
277        || rawnoise_design.ncols() != transform.projection_coef.ncols()
278    {
279        return Err(dim_err("scale deviation apply column mismatch"));
280    }
281    let n = rawnoise_design.nrows();
282    let p_primary = primary_design.ncols();
283    let p_noise = rawnoise_design.ncols();
284    let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
285    let mut out = Array2::<f64>::zeros((n, p_noise));
286    for start in (0..n).step_by(chunk_rows) {
287        let end = (start + chunk_rows).min(n);
288        let primary_chunk = primary_design.slice(s![start..end, ..]).to_owned();
289        let noise_chunk = rawnoise_design.slice(s![start..end, ..]).to_owned();
290        let chunk = apply_scale_deviation_reparam_chunk(&primary_chunk, &noise_chunk, transform);
291        out.slice_mut(s![start..end, ..]).assign(&chunk);
292    }
293    Ok(out)
294}
295
/// Lazily evaluated scale-deviation design. Each row chunk is produced on
/// demand by replaying `transform` on the matching primary and raw-noise rows,
/// rather than storing the dense reparameterized design.
#[derive(Clone)]
struct ScaleDeviationOperator {
    /// Primary design whose rows are fed to the transform replay.
    primary_design: DesignMatrix,
    /// Raw (untransformed) noise design; also defines the operator's shape.
    rawnoise_design: DesignMatrix,
    /// Fitted reparameterization replayed on every chunk.
    transform: ScaleDeviationTransform,
    /// Rows materialized per chunk by the operator's loops.
    chunk_rows: usize,
}
303
304impl ScaleDeviationOperator {
305    fn row_chunk(&self, rows: Range<usize>) -> Result<Array2<f64>, ScaleDesignError> {
306        let primary_chunk = self
307            .primary_design
308            .try_row_chunk(rows.clone())
309            .map_err(|e| ScaleDesignError::RowMaterializationFailed {
310                reason: format!("scale deviation operator primary chunk: {e}"),
311            })?;
312        let noise_chunk = self.rawnoise_design.try_row_chunk(rows).map_err(|e| {
313            ScaleDesignError::RowMaterializationFailed {
314                reason: format!("scale deviation operator noise chunk: {e}"),
315            }
316        })?;
317        Ok(apply_scale_deviation_reparam_chunk(
318            &primary_chunk,
319            &noise_chunk,
320            &self.transform,
321        ))
322    }
323}
324
325impl LinearOperator for ScaleDeviationOperator {
326    fn nrows(&self) -> usize {
327        self.rawnoise_design.nrows()
328    }
329
330    fn ncols(&self) -> usize {
331        self.rawnoise_design.ncols()
332    }
333
334    fn apply(&self, vector: &Array1<f64>) -> Array1<f64> {
335        assert_eq!(vector.len(), self.ncols());
336        let n = self.nrows();
337        let mut out = Array1::<f64>::zeros(n);
338        for start in (0..n).step_by(self.chunk_rows) {
339            let end = (start + self.chunk_rows).min(n);
340            let chunk = self
341                .row_chunk(start..end)
342                .expect("scale deviation operator row chunk failed");
343            out.slice_mut(s![start..end]).assign(&chunk.dot(vector));
344        }
345        out
346    }
347
348    fn apply_transpose(&self, vector: &Array1<f64>) -> Array1<f64> {
349        assert_eq!(vector.len(), self.nrows());
350        let n = self.nrows();
351        let p = self.ncols();
352        let mut out = Array1::<f64>::zeros(p);
353        for start in (0..n).step_by(self.chunk_rows) {
354            let end = (start + self.chunk_rows).min(n);
355            let chunk = self
356                .row_chunk(start..end)
357                .expect("scale deviation operator row chunk failed");
358            out += &chunk.t().dot(&vector.slice(s![start..end]).to_owned());
359        }
360        out
361    }
362
363    fn diag_xtw_x(&self, weights: &Array1<f64>) -> Result<Array2<f64>, String> {
364        if weights.len() != self.nrows() {
365            return Err(dim_err(format!(
366                "scale deviation operator XtWX weight mismatch: weights={}, rows={}",
367                weights.len(),
368                self.nrows()
369            ))
370            .to_string());
371        }
372        let n = self.nrows();
373        let p = self.ncols();
374        let mut out = Array2::<f64>::zeros((p, p));
375        for start in (0..n).step_by(self.chunk_rows) {
376            let end = (start + self.chunk_rows).min(n);
377            let chunk = self.row_chunk(start..end).map_err(|e| e.to_string())?;
378            for local in 0..chunk.nrows() {
379                let w = weights[start + local].max(0.0);
380                if w == 0.0 {
381                    continue;
382                }
383                for a in 0..p {
384                    let xa = chunk[[local, a]];
385                    for b in a..p {
386                        let value = w * xa * chunk[[local, b]];
387                        out[[a, b]] += value;
388                        if a != b {
389                            out[[b, a]] += value;
390                        }
391                    }
392                }
393            }
394        }
395        Ok(out)
396    }
397
398    fn uses_matrix_free_pcg(&self) -> bool {
399        self.primary_design
400            .nrows()
401            .saturating_mul(self.rawnoise_design.ncols())
402            > SCALE_OPERATOR_MATRIX_FREE_PCG_THRESHOLD
403    }
404}
405
406impl DenseDesignOperator for ScaleDeviationOperator {
407    fn row_chunk_into(
408        &self,
409        rows: Range<usize>,
410        mut out: ArrayViewMut2<'_, f64>,
411    ) -> Result<(), gam_runtime::resource::MatrixMaterializationError> {
412        let chunk = self.row_chunk(rows).map_err(|err| {
413            gam_runtime::resource::MatrixMaterializationError::RowMaterializationFailed {
414                context: "ScaleDeviationOperator::row_chunk_into",
415                reason: err.to_string(),
416            }
417        })?;
418        out.assign(&chunk);
419        Ok(())
420    }
421
422    fn to_dense(&self) -> Array2<f64> {
423        let n = self.nrows();
424        let p = self.ncols();
425        let mut out = Array2::<f64>::zeros((n, p));
426        for start in (0..n).step_by(self.chunk_rows) {
427            let end = (start + self.chunk_rows).min(n);
428            let chunk = self
429                .row_chunk(start..end)
430                .expect("scale deviation operator row chunk failed");
431            out.slice_mut(s![start..end, ..]).assign(&chunk);
432        }
433        out
434    }
435}
436
/// Weighted first and second column moments accumulated by
/// `weighted_column_stats`.
#[derive(Debug)]
struct WeightedColumnStats {
    /// Per column: `sum_i w_i * x_ij` (zero-weight rows skipped).
    weighted_sum: Array1<f64>,
    /// Per column: `sum_i w_i * x_ij^2` (zero-weight rows skipped).
    weighted_sum_sq: Array1<f64>,
    /// `sum_i w_i`, as returned by `validate_scale_weights`.
    total_weight: f64,
}
443
444fn validate_scale_weights(weights: &Array1<f64>) -> Result<f64, ScaleDesignError> {
445    let mut total_weight = 0.0;
446    for (idx, &w) in weights.iter().enumerate() {
447        if !w.is_finite() {
448            return Err(ScaleDesignError::NonFiniteInput {
449                reason: format!("scale deviation weight {idx} is not finite"),
450            });
451        }
452        if w < 0.0 {
453            return Err(ScaleDesignError::InvalidWeights {
454                reason: format!(
455                    "scale deviation requires non-negative weights, got {w} at index {idx}"
456                ),
457            });
458        }
459        total_weight += w;
460    }
461    if !total_weight.is_finite() || total_weight <= 0.0 {
462        return Err(ScaleDesignError::InvalidWeights {
463            reason: "scale deviation requires positive finite total weight".to_string(),
464        });
465    }
466    Ok(total_weight)
467}
468
469fn scale_design_row_chunk_size(nrows: usize, max_cols: usize) -> usize {
470    (SCALE_DESIGN_TARGET_CHUNK_BYTES / (max_cols.max(1) * std::mem::size_of::<f64>()))
471        .max(1)
472        .min(nrows.max(1))
473}
474
475fn weighted_column_stats(
476    design: ScaleDesignMatrixRef<'_>,
477    weights: &Array1<f64>,
478    row_mismatch_error: String,
479) -> Result<WeightedColumnStats, ScaleDesignError> {
480    if design.nrows() != weights.len() {
481        return Err(dim_err(row_mismatch_error));
482    }
483    let total_weight = validate_scale_weights(weights)?;
484    let p = design.ncols();
485    let mut weighted_sum = Array1::<f64>::zeros(p);
486    let mut weighted_sum_sq = Array1::<f64>::zeros(p);
487    let chunk_rows = scale_design_row_chunk_size(design.nrows(), p);
488    for start in (0..design.nrows()).step_by(chunk_rows) {
489        let end = (start + chunk_rows).min(design.nrows());
490        let chunk = design.row_chunk(start..end)?;
491        for local in 0..(end - start) {
492            let w = weights[start + local];
493            if w == 0.0 {
494                continue;
495            }
496            for j in 0..p {
497                let x = chunk[[local, j]];
498                weighted_sum[j] += w * x;
499                weighted_sum_sq[j] += w * x * x;
500            }
501        }
502    }
503    Ok(WeightedColumnStats {
504        weighted_sum,
505        weighted_sum_sq,
506        total_weight,
507    })
508}
509
510fn infer_non_intercept_start_impl(
511    design: ScaleDesignMatrixRef<'_>,
512    weights: &Array1<f64>,
513    row_mismatch_error: String,
514) -> Result<usize, ScaleDesignError> {
515    let stats = weighted_column_stats(design, weights, row_mismatch_error)?;
516    let mut end = 0;
517    for j in 0..stats.weighted_sum.len() {
518        let centered_ss = stats.weighted_sum_sq[j]
519            - stats.weighted_sum[j] * stats.weighted_sum[j] / stats.total_weight;
520        if centered_ss <= COLUMN_TOL {
521            end = j + 1;
522        } else {
523            break;
524        }
525    }
526    Ok(end)
527}
528
529fn build_weighted_primary_design(
530    primary_design: ScaleDesignMatrixRef<'_>,
531    sqrtw: &Array1<f64>,
532    chunk_rows: usize,
533) -> Result<Array2<f64>, ScaleDesignError> {
534    let n = primary_design.nrows();
535    let p_primary = primary_design.ncols();
536    let mut wx = Array2::<f64>::zeros((n, p_primary));
537    for start in (0..n).step_by(chunk_rows) {
538        let end = (start + chunk_rows).min(n);
539        let x_chunk = primary_design.row_chunk(start..end)?;
540        for local in 0..(end - start) {
541            let sw = sqrtw[start + local];
542            for col in 0..p_primary {
543                wx[[start + local, col]] = sw * x_chunk[[local, col]];
544            }
545        }
546    }
547    Ok(wx)
548}
549
550/// Pick the squared singular-value cutoff for the replay solve.
551///
552/// Retained directions use the exact inverse `1 / sigma_k`; directions at or
553/// below `sqrt(alpha)` are dropped. We want the worst-case prediction-row
554/// leverage amplification — a unit-norm new row transformed by the saved
555/// coefficients — to be at most `SCALE_PROJECTION_LEVERAGE_AMPLIFICATION`
556/// times what a sigma_max-scale direction sees in the un-regularized solve.
557/// The rcond floor supplies the minimum cutoff for numerical conditioning.
558fn choose_scale_projection_ridge_alpha(singular: &[f64]) -> f64 {
559    if singular.is_empty() {
560        return 0.0;
561    }
562    let sigma_max = singular.iter().copied().fold(0.0_f64, f64::max);
563    if !sigma_max.is_finite() || sigma_max <= 0.0 {
564        return 0.0;
565    }
566    let derived_tol = sigma_max / SCALE_PROJECTION_LEVERAGE_AMPLIFICATION;
567    let truncation_tol = derived_tol.max(SCALE_PROJECTION_REPLAY_RCOND_FLOOR * sigma_max);
568    truncation_tol * truncation_tol
569}
570
571fn solve_scale_projection(
572    primary_design: ScaleDesignMatrixRef<'_>,
573    noise_design: ScaleDesignMatrixRef<'_>,
574    weights: &Array1<f64>,
575    first_active: usize,
576    chunk_rows: usize,
577) -> Result<(Array2<f64>, f64), ScaleDesignError> {
578    let n = primary_design.nrows();
579    let p_primary = primary_design.ncols();
580    let p_noise = noise_design.ncols();
581    let mut projection_coef = Array2::<f64>::zeros((p_primary, p_noise));
582    let active_cols = p_noise.saturating_sub(first_active);
583
584    if active_cols == 0 || p_primary == 0 {
585        return Ok((projection_coef, 0.0));
586    }
587
588    let sqrtw = weights.mapv(f64::sqrt);
589    let wx = build_weighted_primary_design(primary_design, &sqrtw, chunk_rows)?;
590    // Thin SVD of W^{1/2} X_primary: replay reduces to V * diag(filter) * U^T
591    // applied to the weighted noise RHS. Retained singular directions use the
592    // exact inverse; unresolved directions are dropped by the cutoff below.
593    let (u_opt, singular, vt_opt) =
594        wx.svd(true, true)
595            .map_err(|e| ScaleDesignError::SvdFailed {
596                reason: format!("scale projection SVD failed: {e:?}"),
597            })?;
598    let (Some(u), Some(vt)) = (u_opt, vt_opt) else {
599        return Err(ScaleDesignError::SvdFailed {
600            reason: "scale projection SVD did not return singular vectors".to_string(),
601        });
602    };
603    let alpha = choose_scale_projection_ridge_alpha(singular.as_slice().unwrap_or(&[]));
604    let rank = singular.len();
605    if rank == 0 {
606        return Ok((projection_coef, alpha));
607    }
608    // Truncated SVD with leverage-bound cutoff: directions resolved well
609    // enough to keep coefficient amplification under
610    // SCALE_PROJECTION_LEVERAGE_AMPLIFICATION are inverted exactly (no
611    // damping on the dominant components), and weaker directions are
612    // dropped. The primary design is fixed across any single replay, so no
613    // threshold-crossings occur within a call: the projection is a linear
614    // function of the noise RHS, which is the continuity property the audit
615    // asked for. The discarded singular value floor sqrt(alpha) doubles as
616    // the recovered-coefficient leverage cap.
617    let cutoff = alpha.sqrt();
618    let mut filter = Array1::<f64>::zeros(rank);
619    for k in 0..rank {
620        let s = singular[k];
621        filter[k] = if s > cutoff && s > 0.0 { 1.0 / s } else { 0.0 };
622    }
623
624    let chunk_cols = (SCALE_DESIGN_TARGET_CHUNK_BYTES / (n.max(1) * std::mem::size_of::<f64>()))
625        .max(1)
626        .min(active_cols);
627
628    for chunk_start in (0..active_cols).step_by(chunk_cols) {
629        let width = (active_cols - chunk_start).min(chunk_cols);
630        let mut rhs = Array2::<f64>::zeros((n, width));
631        for start in (0..n).step_by(chunk_rows) {
632            let end = (start + chunk_rows).min(n);
633            let noise_chunk = noise_design.row_chunk(start..end)?;
634            for local in 0..(end - start) {
635                let sw = sqrtw[start + local];
636                for col in 0..width {
637                    rhs[[start + local, col]] =
638                        sw * noise_chunk[[local, first_active + chunk_start + col]];
639                }
640            }
641        }
642
643        // U^T (rank x n) * rhs (n x width) -> (rank x width)
644        let mut t = u.t().dot(&rhs);
645        // Apply filter rowwise: t_k *= 1 / sigma_k for retained directions.
646        for k in 0..rank {
647            let f = filter[k];
648            for col in 0..width {
649                t[[k, col]] *= f;
650            }
651        }
652        // V (p_primary x rank) * t (rank x width) -> (p_primary x width).
653        // vt has shape (rank, p_primary), so V = vt^T.
654        let block = vt.t().dot(&t);
655        for col in 0..width {
656            for row in 0..p_primary {
657                projection_coef[[row, first_active + chunk_start + col]] = block[[row, col]];
658            }
659        }
660    }
661
662    Ok((projection_coef, alpha))
663}
664
665fn apply_projection_chunk(
666    primary_chunk: &Array2<f64>,
667    projection_coef: &Array2<f64>,
668    first_active: usize,
669) -> Array2<f64> {
670    if first_active >= projection_coef.ncols() {
671        Array2::<f64>::zeros((primary_chunk.nrows(), 0))
672    } else {
673        fast_ab(
674            primary_chunk,
675            &projection_coef.slice(s![.., first_active..]).to_owned(),
676        )
677    }
678}
679
/// Fit a [`ScaleDeviationTransform`] for `noise_design` relative to
/// `primary_design` under `weights`.
///
/// Steps, as implemented below:
/// 1. Check that all inputs share a row count and that the weights are valid.
/// 2. Solve the truncated-SVD weighted projection of noise columns
///    `first_active..` onto the primary design (`solve_scale_projection`).
/// 3. Replay only that projection (zero centering, unit rescale). Record the
///    weighted mean of each replayed active column as its centering value.
/// 4. In a second pass, compute weighted centered sums of squares of the raw
///    and replayed columns, and set `rescale = sqrt(orig_css / resid_css)`.
///    The rescale falls back to 1.0 when either is degenerate or non-finite.
///
/// Columns before `first_active = min(non_intercept_start, p_noise)` keep zero
/// projection, zero mean and unit rescale.
///
/// # Errors
/// - `IncompatibleDimensions` carrying `row_mismatch_error` when row counts
///   differ.
/// - Weight validation errors.
/// - Row materialization and SVD errors from the helpers.
fn build_scale_deviation_transform_impl(
    primary_design: ScaleDesignMatrixRef<'_>,
    noise_design: ScaleDesignMatrixRef<'_>,
    weights: &Array1<f64>,
    non_intercept_start: usize,
    row_mismatch_error: &str,
) -> Result<ScaleDeviationTransform, ScaleDesignError> {
    if primary_design.nrows() != noise_design.nrows() || weights.len() != noise_design.nrows() {
        return Err(dim_err(row_mismatch_error.to_string()));
    }
    // Rejects non-finite or negative weights and a non-positive total.
    validate_scale_weights(weights)?;

    let n = primary_design.nrows();
    let p_primary = primary_design.ncols();
    let p_noise = noise_design.ncols();
    // Leading noise columns before this index are left untouched.
    let first_active = non_intercept_start.min(p_noise);
    let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
    let (projection_coef, projection_ridge_alpha) = solve_scale_projection(
        primary_design,
        noise_design,
        weights,
        first_active,
        chunk_rows,
    )?;
    let mut weighted_column_mean = Array1::<f64>::zeros(p_noise);
    let mut rescale = Array1::<f64>::ones(p_noise);
    let active_cols = p_noise - first_active;

    if active_cols > 0 {
        // Same projection with identity centering/rescale. NOTE(review):
        // presumably replaying it yields the raw projection residuals
        // `Z - X·projection_coef` — confirm against
        // `apply_scale_deviation_reparam_chunk`.
        let projection_only_transform = ScaleDeviationTransform {
            projection_coef: projection_coef.clone(),
            weighted_column_mean: Array1::<f64>::zeros(p_noise),
            rescale: Array1::<f64>::ones(p_noise),
            non_intercept_start,
            projection_ridge_alpha,
        };
        let mut w_sum = 0.0;
        let mut w_resid_sum = Array1::<f64>::zeros(active_cols);
        let mut w_noise_sum = Array1::<f64>::zeros(active_cols);

        // First pass: weighted sums of the raw and replayed active columns.
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let x_chunk = primary_design.row_chunk(start..end)?;
            let noise_chunk = noise_design.row_chunk(start..end)?;
            let resid_chunk = apply_scale_deviation_reparam_chunk(
                &x_chunk,
                &noise_chunk,
                &projection_only_transform,
            );
            for local in 0..(end - start) {
                let w = weights[start + local];
                if w == 0.0 {
                    continue;
                }
                w_sum += w;
                for jj in 0..active_cols {
                    let nij = noise_chunk[[local, first_active + jj]];
                    w_noise_sum[jj] += w * nij;
                    w_resid_sum[jj] += w * resid_chunk[[local, first_active + jj]];
                }
            }
        }

        // Should already hold after `validate_scale_weights`; kept as a
        // defensive guard before dividing by `w_sum`.
        if !w_sum.is_finite() || w_sum <= 0.0 {
            return Err(ScaleDesignError::InvalidWeights {
                reason: "scale deviation requires positive finite total weight".to_string(),
            });
        }

        let resid_center = w_resid_sum.mapv(|sum| sum / w_sum);
        let noise_mean = w_noise_sum.mapv(|sum| sum / w_sum);
        let mut orig_css = Array1::<f64>::zeros(active_cols);
        let mut resid_css = Array1::<f64>::zeros(active_cols);

        // Second pass: weighted centered sums of squares about the means above.
        // The two-pass form avoids the `sum_sq - sum^2 / W` cancellation.
        for start in (0..n).step_by(chunk_rows) {
            let end = (start + chunk_rows).min(n);
            let x_chunk = primary_design.row_chunk(start..end)?;
            let noise_chunk = noise_design.row_chunk(start..end)?;
            let resid_chunk = apply_scale_deviation_reparam_chunk(
                &x_chunk,
                &noise_chunk,
                &projection_only_transform,
            );
            for local in 0..(end - start) {
                let w = weights[start + local];
                if w == 0.0 {
                    continue;
                }
                for jj in 0..active_cols {
                    let nij = noise_chunk[[local, first_active + jj]];
                    let d_orig = nij - noise_mean[jj];
                    orig_css[jj] += w * d_orig * d_orig;
                    let d_resid = resid_chunk[[local, first_active + jj]] - resid_center[jj];
                    resid_css[jj] += w * d_resid * d_resid;
                }
            }
        }

        // Rescale factor sqrt(orig_css / resid_css). NOTE(review): presumably
        // applied as a multiplier so the residualized column regains the raw
        // column's weighted spread — confirm. Degenerate or non-finite cases
        // keep 1.0.
        for jj in 0..active_cols {
            let j = first_active + jj;
            let scale = if resid_css[jj].is_finite()
                && resid_css[jj] > COLUMN_TOL
                && orig_css[jj].is_finite()
                && orig_css[jj] > COLUMN_TOL
            {
                (orig_css[jj] / resid_css[jj]).sqrt()
            } else {
                1.0
            };
            weighted_column_mean[j] = resid_center[jj];
            rescale[j] = scale;
        }
    }

    Ok(ScaleDeviationTransform {
        projection_coef,
        weighted_column_mean,
        rescale,
        non_intercept_start,
        projection_ridge_alpha,
    })
}
802
803pub fn infer_non_intercept_start_design(
804    design: &DesignMatrix,
805    weights: &Array1<f64>,
806) -> Result<usize, String> {
807    infer_non_intercept_start_impl(
808        ScaleDesignMatrixRef::Design(design),
809        weights,
810        format!(
811            "weighted column stats row mismatch: design has {} rows, weights have {} entries",
812            design.nrows(),
813            weights.len()
814        ),
815    )
816    .map_err(|e| e.to_string())
817}
818
819pub fn build_scale_deviation_transform_design(
820    primary_design: &DesignMatrix,
821    noise_design: &DesignMatrix,
822    weights: &Array1<f64>,
823    non_intercept_start: usize,
824) -> Result<ScaleDeviationTransform, String> {
825    build_scale_deviation_transform_impl(
826        ScaleDesignMatrixRef::Design(primary_design),
827        ScaleDesignMatrixRef::Design(noise_design),
828        weights,
829        non_intercept_start,
830        "scale deviation transform design row mismatch",
831    )
832    .map_err(|e| e.to_string())
833}
834
835/// Apply the scale-deviation reparameterisation to a chunk of rows.
836///
837/// Instead of embedding the projection coefficients into a large augmented
838/// matrix (which changes FP operation order relative to the canonical
839/// `apply_projection_chunk`), we compute the projection via the shared
840/// helper and then fold in rescaling and centering explicitly.  This
841/// guarantees bit-identical projection arithmetic on both paths.
842fn apply_scale_deviation_reparam_chunk(
843    primary_chunk: &Array2<f64>,
844    noise_chunk: &Array2<f64>,
845    transform: &ScaleDeviationTransform,
846) -> Array2<f64> {
847    let rows = noise_chunk.nrows();
848    let p_noise = noise_chunk.ncols();
849    let first_active = transform.non_intercept_start.min(p_noise);
850    let mut out = Array2::<f64>::zeros((rows, p_noise));
851
852    // Pass-through columns (intercept-like) are copied verbatim.
853    for j in 0..first_active {
854        for i in 0..rows {
855            out[[i, j]] = noise_chunk[[i, j]];
856        }
857    }
858
859    // Active columns: residual = noise - projection, then center & rescale.
860    if first_active < p_noise {
861        let fitted =
862            apply_projection_chunk(primary_chunk, &transform.projection_coef, first_active);
863        for j in first_active..p_noise {
864            let jj = j - first_active;
865            let scale = transform.rescale[j];
866            let center = transform.weighted_column_mean[j];
867            for i in 0..rows {
868                out[[i, j]] = (noise_chunk[[i, j]] - fitted[[i, jj]] - center) * scale;
869            }
870        }
871    }
872
873    out
874}
875
876pub fn build_scale_deviation_operator(
877    primary_design: DesignMatrix,
878    rawnoise_design: DesignMatrix,
879    transform: &ScaleDeviationTransform,
880) -> Result<DesignMatrix, String> {
881    build_scale_deviation_operator_typed(primary_design, rawnoise_design, transform)
882        .map_err(|e| e.to_string())
883}
884
885fn build_scale_deviation_operator_typed(
886    primary_design: DesignMatrix,
887    rawnoise_design: DesignMatrix,
888    transform: &ScaleDeviationTransform,
889) -> Result<DesignMatrix, ScaleDesignError> {
890    if primary_design.nrows() != rawnoise_design.nrows() {
891        return Err(dim_err(format!(
892            "scale deviation operator row mismatch: primary rows={}, noise rows={}",
893            primary_design.nrows(),
894            rawnoise_design.nrows()
895        )));
896    }
897    if primary_design.ncols() != transform.projection_coef.nrows()
898        || rawnoise_design.ncols() != transform.projection_coef.ncols()
899    {
900        return Err(dim_err(format!(
901            "scale deviation operator column mismatch: primary cols={}, noise cols={}, transform is {}x{}",
902            primary_design.ncols(),
903            rawnoise_design.ncols(),
904            transform.projection_coef.nrows(),
905            transform.projection_coef.ncols()
906        )));
907    }
908    let n = rawnoise_design.nrows();
909    let p_primary = primary_design.ncols();
910    let p_noise = rawnoise_design.ncols();
911    let chunk_rows = scale_design_row_chunk_size(n, p_primary.max(p_noise));
912    Ok(DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(
913        ScaleDeviationOperator {
914            primary_design,
915            rawnoise_design,
916            transform: transform.clone(),
917            chunk_rows,
918        },
919    ))))
920}
921
922#[cfg(test)]
923mod tests {
924    use super::*;
925    use gam_linalg::matrix::DesignMatrix;
926
927    fn assert_matrix_close(lhs: &Array2<f64>, rhs: &Array2<f64>, tol: f64, label: &str) {
928        assert_eq!(
929            lhs.dim(),
930            rhs.dim(),
931            "{label} shape mismatch: left {:?}, right {:?}",
932            lhs.dim(),
933            rhs.dim()
934        );
935        for i in 0..lhs.nrows() {
936            for j in 0..lhs.ncols() {
937                assert!(
938                    (lhs[[i, j]] - rhs[[i, j]]).abs() <= tol,
939                    "{label} mismatch at ({i}, {j}): {} vs {}",
940                    lhs[[i, j]],
941                    rhs[[i, j]]
942                );
943            }
944        }
945    }
946
947    fn assert_transform_close(
948        lhs: &ScaleDeviationTransform,
949        rhs: &ScaleDeviationTransform,
950        tol: f64,
951    ) {
952        assert_eq!(lhs.non_intercept_start, rhs.non_intercept_start);
953        assert_matrix_close(
954            &lhs.projection_coef,
955            &rhs.projection_coef,
956            tol,
957            "projection coefficients",
958        );
959        assert_eq!(
960            lhs.weighted_column_mean.len(),
961            rhs.weighted_column_mean.len()
962        );
963        assert_eq!(lhs.rescale.len(), rhs.rescale.len());
964        for j in 0..lhs.weighted_column_mean.len() {
965            assert!(
966                (lhs.weighted_column_mean[j] - rhs.weighted_column_mean[j]).abs() <= tol,
967                "weighted column mean mismatch at {j}: {} vs {}",
968                lhs.weighted_column_mean[j],
969                rhs.weighted_column_mean[j]
970            );
971            assert!(
972                (lhs.rescale[j] - rhs.rescale[j]).abs() <= tol,
973                "rescale mismatch at {j}: {} vs {}",
974                lhs.rescale[j],
975                rhs.rescale[j]
976            );
977        }
978    }
979
980    #[test]
981    fn scale_deviation_transform_overdetermined() {
982        let n = 1000;
983        let p_primary = 10;
984        let p_noise = 5;
985
986        let mut primary = Array2::<f64>::zeros((n, p_primary));
987        let mut noise = Array2::<f64>::zeros((n, p_noise));
988        for i in 0..n {
989            for j in 0..p_primary {
990                primary[[i, j]] = ((i * 3 + j * 11) as f64 * 0.1).sin();
991            }
992            for j in 0..p_noise {
993                noise[[i, j]] = ((i * 5 + j * 13) as f64 * 0.1).cos();
994            }
995        }
996        noise.column_mut(0).fill(1.0);
997        let weights = Array1::<f64>::ones(n);
998
999        let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1000            .expect("transform should succeed for overdetermined inputs");
1001        let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
1002            .expect("apply should succeed for overdetermined inputs");
1003
1004        assert_eq!(transform.projection_coef.dim(), (p_primary, p_noise));
1005        assert_eq!(transformed.dim(), (n, p_noise));
1006        assert!(transformed.iter().all(|v| v.is_finite()));
1007        assert!(transformed.column(0).iter().all(|&v| v == 1.0));
1008
1009        let primary_design =
1010            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(primary.clone()));
1011        let noise_design =
1012            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(noise.clone()));
1013        let non_intercept_start = infer_non_intercept_start_design(&noise_design, &weights)
1014            .expect("design-native non-intercept detection should succeed");
1015        assert_eq!(non_intercept_start, 1);
1016        let design_transform = build_scale_deviation_transform_design(
1017            &primary_design,
1018            &noise_design,
1019            &weights,
1020            non_intercept_start,
1021        )
1022        .expect("design-native transform should succeed");
1023        let transformed_design =
1024            build_scale_deviation_operator(primary_design, noise_design, &design_transform)
1025                .expect("design-native operator should build")
1026                .to_dense();
1027
1028        assert_eq!(design_transform.projection_coef.dim(), (p_primary, p_noise));
1029        assert_eq!(transformed_design.dim(), transformed.dim());
1030        assert_transform_close(&transform, &design_transform, 1e-10);
1031        assert_matrix_close(
1032            &transformed_design,
1033            &transformed,
1034            1e-8,
1035            "transformed design",
1036        );
1037    }
1038
1039    #[test]
1040    fn scale_deviation_transform_rank_deficient_primary_matches_design_path() {
1041        let n = 384;
1042        let p_primary = 4;
1043        let p_noise = 4;
1044        let mut primary = Array2::<f64>::zeros((n, p_primary));
1045        let mut noise = Array2::<f64>::zeros((n, p_noise));
1046        let mut weights = Array1::<f64>::zeros(n);
1047
1048        for i in 0..n {
1049            let t = i as f64 / n as f64;
1050            let wobble = (17.0 * t).sin();
1051            primary[[i, 0]] = 1.0;
1052            primary[[i, 1]] = t;
1053            primary[[i, 2]] = t + 1e-12 * wobble;
1054            primary[[i, 3]] = 2.0 * t - 1e-12 * wobble;
1055
1056            noise[[i, 0]] = 1.0;
1057            noise[[i, 1]] = 0.7 * t + 0.2 * (9.0 * t).cos();
1058            noise[[i, 2]] = primary[[i, 1]] - primary[[i, 2]] + 0.1 * (13.0 * t).sin();
1059            noise[[i, 3]] = 0.5 * primary[[i, 3]] + 0.3 * (5.0 * t).cos();
1060
1061            weights[i] = if i % 17 == 0 {
1062                0.0
1063            } else {
1064                0.5 + (11.0 * t).sin().abs()
1065            };
1066        }
1067
1068        let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1069            .expect("dense transform should succeed for ill-conditioned primary");
1070        let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
1071            .expect("dense apply should succeed for ill-conditioned primary");
1072
1073        let primary_design =
1074            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(primary.clone()));
1075        let noise_design =
1076            DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(noise.clone()));
1077        let non_intercept_start = infer_non_intercept_start_design(&noise_design, &weights)
1078            .expect("design-native non-intercept detection should succeed");
1079        assert_eq!(non_intercept_start, 1);
1080
1081        let design_transform = build_scale_deviation_transform_design(
1082            &primary_design,
1083            &noise_design,
1084            &weights,
1085            non_intercept_start,
1086        )
1087        .expect("design-native transform should succeed for ill-conditioned primary");
1088        let transformed_design =
1089            build_scale_deviation_operator(primary_design, noise_design, &design_transform)
1090                .expect("design-native operator should build for ill-conditioned primary")
1091                .to_dense();
1092
1093        assert_transform_close(&transform, &design_transform, 1e-10);
1094        assert_matrix_close(
1095            &transformed_design,
1096            &transformed,
1097            1e-8,
1098            "ill-conditioned transformed design",
1099        );
1100    }
1101
1102    #[test]
1103    fn choose_scale_projection_ridge_alpha_scales_with_sigma_max() {
1104        // Truncation tolerance is `RCOND_FLOOR * sigma_max` whenever the
1105        // leverage cap is looser (which it always is for the default 1e8
1106        // value), so alpha = (RCOND_FLOOR * sigma_max)^2.
1107        let alpha_unit = choose_scale_projection_ridge_alpha(&[1.0, 0.5, 1e-6]);
1108        let expected_unit = SCALE_PROJECTION_REPLAY_RCOND_FLOOR.powi(2);
1109        assert!(alpha_unit > 0.0);
1110        assert!(
1111            (alpha_unit - expected_unit).abs() < 1e-24,
1112            "alpha should be {expected_unit:e} for sigma_max=1, got {alpha_unit}"
1113        );
1114
1115        let alpha_scaled = choose_scale_projection_ridge_alpha(&[100.0, 1.0]);
1116        let expected_scaled = (SCALE_PROJECTION_REPLAY_RCOND_FLOOR * 100.0).powi(2);
1117        assert!(
1118            (alpha_scaled - expected_scaled).abs() < 1e-18,
1119            "alpha should be {expected_scaled:e} for sigma_max=100, got {alpha_scaled}"
1120        );
1121        // Scales as sigma_max^2.
1122        assert!(
1123            (alpha_scaled / alpha_unit - 1.0e4).abs() < 1e-6,
1124            "alpha should scale as sigma_max^2; got ratio {}",
1125            alpha_scaled / alpha_unit
1126        );
1127
1128        let alpha_floor = choose_scale_projection_ridge_alpha(&[]);
1129        assert_eq!(alpha_floor, 0.0);
1130    }
1131
1132    #[test]
1133    fn ridge_replay_continuous_under_input_sweep() {
1134        // A near-collinear primary design plus a sweepable perturbation column
1135        // would, under the old hard coefficient cap, jump discontinuously when
1136        // the cap kicks in. With a fixed SVD cutoff, the replayed coefficient
1137        // is a linear function of the input perturbation.
1138        let n = 64;
1139        let mut primary = Array2::<f64>::zeros((n, 3));
1140        let mut noise = Array2::<f64>::zeros((n, 2));
1141        let weights = Array1::<f64>::ones(n);
1142        for i in 0..n {
1143            let t = i as f64 / n as f64;
1144            primary[[i, 0]] = 1.0;
1145            primary[[i, 1]] = t;
1146            // Near-collinear with col 1 — this is the high-gain direction.
1147            primary[[i, 2]] = t + 1e-9 * (5.0 * t).sin();
1148            noise[[i, 0]] = 1.0;
1149            noise[[i, 1]] = (0.4 * t).cos();
1150        }
1151
1152        // Sweep: gradually scale one noise entry; record the corresponding
1153        // projected coefficient cell. Numerical first differences should be
1154        // bounded because the fixed projection operator is linear in the input.
1155        let mut last: Option<f64> = None;
1156        let mut max_step: f64 = 0.0;
1157        for k in 0..50 {
1158            let s = k as f64 / 49.0;
1159            let mut perturbed = noise.clone();
1160            for i in 0..n {
1161                perturbed[[i, 1]] += s;
1162            }
1163            let transform = build_scale_deviation_transform(&primary, &perturbed, &weights, 1)
1164                .expect("ridge transform should succeed under input sweep");
1165            let val = transform.projection_coef[[2, 1]];
1166            if let Some(prev) = last {
1167                let step = (val - prev).abs();
1168                max_step = max_step.max(step);
1169            }
1170            last = Some(val);
1171        }
1172        // Step bound: with 50 samples over a unit sweep, a smooth dependence
1173        // produces uniform tiny jumps.  The old coefficient cap would emit a
1174        // single huge step at the cap boundary, easily blowing 1.0 here.
1175        assert!(
1176            max_step < 0.5,
1177            "replay coefficient sweep should be continuous, got max step {max_step}"
1178        );
1179    }
1180
1181    #[test]
1182    fn ridge_replay_noise_free_is_near_identity() {
1183        // When the noise design lives in the column span of the primary
1184        // design and W^{1/2} X is well-conditioned, the retained singular
1185        // directions use the exact inverse and the residual after subtracting
1186        // the projected fit is at numerical zero.
1187        let n = 128;
1188        let p_primary = 4;
1189        let p_noise = 3;
1190        let mut primary = Array2::<f64>::zeros((n, p_primary));
1191        let mut noise = Array2::<f64>::zeros((n, p_noise));
1192        let weights = Array1::<f64>::ones(n);
1193        for i in 0..n {
1194            let t = i as f64 / n as f64;
1195            primary[[i, 0]] = 1.0;
1196            primary[[i, 1]] = t;
1197            primary[[i, 2]] = (3.0 * t).sin();
1198            primary[[i, 3]] = (2.0 * t - 0.4).powi(2);
1199            noise[[i, 0]] = 1.0;
1200            // Linear combinations of primary cols so the projection should
1201            // recover them through the retained exact-inverse directions.
1202            noise[[i, 1]] = 0.7 * primary[[i, 1]] - 0.3 * primary[[i, 2]];
1203            noise[[i, 2]] = 0.2 * primary[[i, 3]] + 0.1 * primary[[i, 1]];
1204        }
1205
1206        let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1207            .expect("transform should succeed");
1208        let transformed = apply_scale_deviation_transform(&primary, &noise, &transform)
1209            .expect("apply should succeed");
1210
1211        // Pass-through column unaffected.
1212        for i in 0..n {
1213            assert_eq!(transformed[[i, 0]], 1.0);
1214        }
1215        // Active columns: residuals should be near zero because the relevant
1216        // singular directions are retained and inverted exactly. The design is
1217        // well-conditioned, so 1e-6 is a safe envelope for roundoff.
1218        for j in 1..p_noise {
1219            for i in 0..n {
1220                assert!(
1221                    transformed[[i, j]].abs() < 1e-6,
1222                    "noise-free residual should be near zero at ({i},{j}), got {}",
1223                    transformed[[i, j]]
1224                );
1225            }
1226        }
1227        assert!(transform.projection_ridge_alpha > 0.0);
1228    }
1229
1230    #[test]
1231    fn scale_transform_payload_round_trips_alpha() {
1232        let n = 64;
1233        let mut primary = Array2::<f64>::zeros((n, 3));
1234        let mut noise = Array2::<f64>::zeros((n, 2));
1235        let weights = Array1::<f64>::ones(n);
1236        for i in 0..n {
1237            let t = i as f64 / n as f64;
1238            primary[[i, 0]] = 1.0;
1239            primary[[i, 1]] = t;
1240            primary[[i, 2]] = (4.0 * t).cos();
1241            noise[[i, 0]] = 1.0;
1242            noise[[i, 1]] = (2.0 * t).sin();
1243        }
1244        let transform = build_scale_deviation_transform(&primary, &noise, &weights, 1)
1245            .expect("transform should succeed");
1246
1247        let projection: Vec<Vec<f64>> = transform
1248            .projection_coef
1249            .rows()
1250            .into_iter()
1251            .map(|row| row.to_vec())
1252            .collect();
1253        let center = transform.weighted_column_mean.to_vec();
1254        let scale = transform.rescale.to_vec();
1255        let restored = scale_transform_from_payload(
1256            &Some(projection),
1257            &Some(center),
1258            &Some(scale),
1259            Some(transform.non_intercept_start),
1260            Some(transform.projection_ridge_alpha),
1261        )
1262        .expect("payload round-trip should succeed")
1263        .expect("payload should produce a transform");
1264        assert_eq!(
1265            restored.projection_ridge_alpha, transform.projection_ridge_alpha,
1266            "alpha must round-trip exactly through payload serialization"
1267        );
1268    }
1269}