Skip to main content

linreg_core/regularized/
preprocess.rs

1//! Data preprocessing for regularized regression.
2//!
3//! This module provides standardization utilities that match glmnet output behavior:
4//!
5//! - Predictors are centered and scaled (if enabled)
6//! - The intercept column is not penalized, so it's handled specially
7//! - Coefficients can be unstandardized back to the original scale
8//! - Observation weights are supported for weighted regression
9//!
10//! # Standardization Convention
11//!
//! The scaling factor used for each predictor is its (weighted) standard deviation —
//! `sqrt(sum((x - mean(x))²) / n)` in the unweighted case — which gives unit variance
//! under the 1/n convention (matching the glmnet paper).
14//!
15//! # Weighted Standardization
16//!
17//! When weights are provided, they are first normalized to sum to 1:
18//! `weights_normalized = w / sum(w)`. Then weighted means and variances are computed.
19
20use crate::linalg::Matrix;
21
/// Information stored during standardization, used to unstandardize coefficients.
///
/// This struct captures all the information needed to transform coefficients
/// from the standardized space back to the original data scale. It is produced by
/// [`standardize_xy`] and consumed by [`unstandardize_coefficients`].
///
/// # Fields
///
/// * `x_mean` - (Weighted) mean of each predictor column (length p). It is 0 for the
///   intercept column, and for every column when no intercept is fitted.
/// * `x_scale` - Scale factor each predictor column was divided by (length p). It is 1
///   for columns that were not scaled.
/// * `column_squared_norms` - Squared norm of each predictor column after
///   standardization, used as the coordinate-descent update denominator.
/// * `y_mean` - (Weighted) mean of the response. It is 0 when no intercept is fitted.
/// * `y_scale` - Norm of the centered, sqrt-weighted response that `y` was divided by.
///   [`standardize_xy`] always sets it to `Some(..)`, except in its degenerate
///   weight-length-mismatch result.
/// * `y_scale_before_sqrt_weights_normalized` - Norm of the centered response *before*
///   the sqrt-weight transformation.
/// * `intercept` - Whether an intercept term was included
/// * `standardized_x` - Whether X was standardized
/// * `standardized_y` - The `standardize_y` option that was requested. [`standardize_xy`]
///   scales the response regardless of this flag.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::preprocess::StandardizationInfo;
/// let info = StandardizationInfo {
///     x_mean: vec![0.0, 5.0],
///     x_scale: vec![1.0, 2.0],
///     column_squared_norms: vec![1.0, 1.0],
///     y_mean: 10.0,
///     y_scale: Some(3.0),
///     y_scale_before_sqrt_weights_normalized: Some(3.0),
///     intercept: true,
///     standardized_x: true,
///     standardized_y: false,
/// };
///
/// assert_eq!(info.x_mean.len(), 2);
/// assert!(info.intercept);
/// ```
#[derive(Clone, Debug)]
pub struct StandardizationInfo {
    /// (Weighted) mean of each predictor column; 0 where no centering was applied
    pub x_mean: Vec<f64>,
    /// Scale factor each predictor column was divided by; 1 where no scaling was applied
    pub x_scale: Vec<f64>,
    /// Squared norm of each predictor column after standardization.
    /// This is used in the coordinate descent update denominator.
    /// - With intercept and standardize: column_squared_norms\[j\] = 1.0 (unit norm after centering)
    /// - Without intercept and standardize: column_squared_norms\[j\] = 1.0 + x_squared_mean/x_centered_variance (glmnet formula)
    /// - Without standardize: column_squared_norms\[j\] = ||x_j||^2 (actual squared norm)
    pub column_squared_norms: Vec<f64>,
    /// (Weighted) mean of the response variable; 0 without an intercept
    pub y_mean: f64,
    /// Scale factor for response (for lambda path construction).
    /// This is the norm AFTER the sqrt_weights_normalized transformation and centering:
    /// sqrt(sum((sqrt_weights_normalized*(y-ym))^2))
    pub y_scale: Option<f64>,
    /// Scale factor for response BEFORE the sqrt_weights_normalized transformation:
    /// sqrt(sum((y-ym)^2)).
    /// This is used for lambda scaling between original and standardized data.
    pub y_scale_before_sqrt_weights_normalized: Option<f64>,
    /// Whether an intercept was included
    pub intercept: bool,
    /// Whether X was standardized
    pub standardized_x: bool,
    /// The requested `standardize_y` option (recorded as-is)
    pub standardized_y: bool,
}
83
/// Options for standardization.
///
/// # Fields
///
/// * `intercept` - Whether to fit an intercept. When set, the first column of X is
///   treated as the intercept column, and the remaining columns and y are centered at
///   their (weighted) means (default: true).
/// * `standardize_x` - Whether to scale predictor columns (default: true)
/// * `standardize_y` - Recorded in `StandardizationInfo::standardized_y` (default: false)
/// * `weights` - Optional observation weights (default: None).
///   If provided, weights are normalized to sum to 1 before use.
///
/// # Note
///
/// `standardize_xy` always scales the response to unit norm (after centering it when
/// `intercept` is set), whatever the value of `standardize_y`. The flag is only carried
/// through to `StandardizationInfo::standardized_y`.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::preprocess::StandardizeOptions;
/// let opts = StandardizeOptions {
///     intercept: true,
///     standardize_x: true,
///     standardize_y: false,
///     weights: None,
/// };
///
/// assert!(opts.intercept);
/// assert!(opts.standardize_x);
/// ```
#[derive(Clone, Debug)]
pub struct StandardizeOptions {
    /// Whether to include an intercept (and center X and y)
    pub intercept: bool,
    /// Whether to scale predictor columns
    pub standardize_x: bool,
    /// Requested response standardization; recorded in the resulting info only
    pub standardize_y: bool,
    /// Optional observation weights (normalized to sum to 1 before use)
    pub weights: Option<Vec<f64>>,
}

impl Default for StandardizeOptions {
    /// Intercept on, predictors scaled, `standardize_y` off, no observation weights.
    fn default() -> Self {
        Self {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights: None,
        }
    }
}
136
137/// Standardizes X and y for regularized regression (glmnet-compatible).
138///
139/// This function performs the same standardization as glmnet with
140/// `standardize=TRUE`. The first column of X is assumed to be the intercept
141/// (all ones) and is NOT standardized.
142///
143/// # Arguments
144///
145/// * `x` - Design matrix (n × p). First column should be intercept if `intercept=true`.
146/// * `y` - Response vector (n elements)
147/// * `options` - Standardization options (including optional observation weights)
148///
149/// # Returns
150///
151/// A tuple `(x_standardized, y_standardized, info)` where:
152/// - `x_standardized` is the standardized design matrix
153/// - `y_standardized` is the (optionally) standardized response
154/// - `info` contains the standardization parameters for unstandardization
155///
156/// # Standardization Details
157///
158/// ## Unweighted case:
159/// For the intercept column (first column, if present):
160/// - Not centered (stays as ones)
161/// - Not scaled
162///
163/// For other columns (if `standardize_x=true`):
164/// - Centered: `x_centered = x - mean(x)`
165/// - Scaled: `x_scaled = x_centered / sqrt(sum(x²))`
166///
167/// For y (if `standardize_y=true`):
168/// - Centered: `y_centered = y - mean(y)`
169/// - Scaled: `y_scaled = y_centered / sqrt(sum(y²))`
170///
171/// ## Weighted case:
172/// Weights are first normalized: `weights_normalized = w / sum(w)`, then `sqrt_weights_normalized = sqrt(weights_normalized)`
173/// - Weighted mean: `ym = sum(w * y) / sum(w) = sum(weights_normalized * y)`
174/// - Weighted variance: `sum(w * (y - ym)^2) / sum(w)`
175/// - Data transformed by `sqrt_weights_normalized`: `y_new = sqrt_weights_normalized * (y - ym)`, then scaled
176#[allow(clippy::needless_range_loop)]
177pub fn standardize_xy(
178    x: &Matrix,
179    y: &[f64],
180    options: &StandardizeOptions,
181) -> (Matrix, Vec<f64>, StandardizationInfo) {
182    let n = x.rows;
183    let p = x.cols;
184
185    // Validate weights if provided
186    if let Some(ref w) = options.weights {
187        if w.len() != n {
188            return (
189                Matrix::new(n, p, vec![0.0; n * p]),
190                vec![0.0; n],
191                StandardizationInfo {
192                    x_mean: vec![0.0; p],
193                    x_scale: vec![1.0; p],
194                    column_squared_norms: vec![0.0; p],
195                    y_mean: 0.0,
196                    y_scale: None,
197                    y_scale_before_sqrt_weights_normalized: None,
198                    intercept: options.intercept,
199                    standardized_x: options.standardize_x,
200                    standardized_y: options.standardize_y,
201                },
202            );
203        }
204        if w.iter().any(|&wi| wi < 0.0) {
205            panic!("Weights must be non-negative");
206        }
207    }
208
209    // Prepare normalized weights and sqrt(weights)
210    // w = w / sum(w) then sqrt_weights_normalized = sqrt(w)
211    let (weights_normalized, sqrt_weights_normalized): (Vec<f64>, Vec<f64>) = if let Some(ref w) = options.weights {
212        let w_sum: f64 = w.iter().sum();
213        if w_sum > 0.0 {
214            let weights_normalized_vec: Vec<f64> = w.iter().map(|&wi| wi / w_sum).collect();
215            let sqrt_weights_normalized_vec: Vec<f64> = weights_normalized_vec.iter().map(|&wi| wi.sqrt()).collect();
216            (weights_normalized_vec, sqrt_weights_normalized_vec)
217        } else {
218            (vec![0.0; n], vec![0.0; n])
219        }
220    } else {
221        // No weights: use uniform weights
222        let w_uniform = vec![1.0 / n as f64; n];
223        let sqrt_weights_normalized_uniform = vec![1.0 / (n as f64).sqrt(); n];
224        (w_uniform, sqrt_weights_normalized_uniform)
225    };
226
227    let mut x_standardized = x.clone();
228    let mut y_standardized = y.to_vec();
229
230    let mut x_mean = vec![0.0; p];
231    let mut x_scale = vec![1.0; p];
232    let mut column_squared_norms = vec![0.0; p];  // Column squared norms for coordinate descent
233
234    let y_mean = if options.intercept && !y.is_empty() {
235        // Weighted mean: ym = sum(w * y)
236        weights_normalized.iter().zip(y.iter()).map(|(&w, &yi)| w * yi).sum()
237    } else {
238        0.0
239    };
240
241    // GLMNET: y is ALWAYS scaled to unit norm!
242    // This is critical for correct lambda_max computation
243    let (y_scale, y_scale_before_sqrt_weights_normalized) = if options.intercept {
244        // WITH INTERCEPT: Center y, then scale to unit norm
245        // First compute y_scale_before_sqrt_weights_normalized (centered but not sqrt_weights_normalized-transformed)
246        let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
247        let y_ss_before_sqrt_weights_normalized: f64 = y_centered.iter().map(|&yi| yi * yi).sum();
248        let y_scale_before_sqrt_weights_normalized_val = y_ss_before_sqrt_weights_normalized.sqrt();
249
250        // Center y: y_new = sqrt_weights_normalized * (y - ym)
251        for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
252            *yi = sqrt_weight * (*yi - y_mean);
253        }
254
255        // Scale to unit norm (GLMNET always does this!)
256        let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
257        let y_scale_val = y_ss.sqrt();
258        if y_scale_val > 0.0 {
259            for yi in y_standardized.iter_mut() {
260                *yi /= y_scale_val;
261            }
262        }
263        (Some(y_scale_val), Some(y_scale_before_sqrt_weights_normalized_val))
264    } else {
265        // WITHOUT INTERCEPT: Don't center y, but DO scale to unit norm (GLMNET output behavior!)
266        // y_new = sqrt_weights_normalized * y, then y = y / ||y||
267        for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
268            *yi *= sqrt_weight;
269        }
270        let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
271        let y_scale_val = y_ss.sqrt();
272        if y_scale_val > 0.0 {
273            for yi in y_standardized.iter_mut() {
274                *yi /= y_scale_val;
275            }
276        }
277        (Some(y_scale_val), Some(y_scale_val))  // y_scale_before_sqrt_weights_normalized = y_scale when no centering
278    };
279
280    // Standardize X columns
281    // If intercept is present, first column is NOT standardized (it's the intercept column)
282    let first_penalized_column_index = if options.intercept { 1 } else { 0 };
283
284    if options.intercept {
285        // WITH INTERCEPT (intercept column not standardized)
286        for j in first_penalized_column_index..p {
287            // Compute weighted column mean and center
288            let col_mean: f64 = (0..n)
289                .map(|i| x_standardized.get(i, j) * weights_normalized[i])
290                .sum();
291            x_mean[j] = col_mean;
292
293            // Center the column and apply sqrt_weights_normalized transformation
294            // x_new = sqrt_weights_normalized * (x - xm)
295            for i in 0..n {
296                let val = sqrt_weights_normalized[i] * (x_standardized.get(i, j) - col_mean);
297                x_standardized.set(i, j, val);
298            }
299
300            // Compute squared norm
301            let col_squared_norm_val: f64 = (0..n)
302                .map(|i| {
303                    let val = x_standardized.get(i, j);
304                    val * val
305                })
306                .sum();
307
308            if options.standardize_x {
309                // Scale to unit norm
310                let col_scale = col_squared_norm_val.sqrt();
311                if col_scale > 0.0 {
312                    for i in 0..n {
313                        let val = x_standardized.get(i, j) / col_scale;
314                        x_standardized.set(i, j, val);
315                    }
316                    x_scale[j] = col_scale;
317                    column_squared_norms[j] = 1.0;  // Unit norm
318                }
319            } else {
320                // No standardization - column_squared_norms stays as the actual squared norm
321                column_squared_norms[j] = col_squared_norm_val;
322                x_scale[j] = 1.0;
323            }
324        }
325    } else {
326        // WITHOUT INTERCEPT (no centering)
327        for j in first_penalized_column_index..p {
328            x_mean[j] = 0.0;  // No centering
329
330            // Apply sqrt_weights_normalized transformation
331            for i in 0..n {
332                let val = sqrt_weights_normalized[i] * x_standardized.get(i, j);
333                x_standardized.set(i, j, val);
334            }
335
336            // Compute squared norm after sqrt_weights_normalized transformation
337            let col_squared_norm_val: f64 = (0..n)
338                .map(|i| {
339                    let val = x_standardized.get(i, j);
340                    val * val
341                })
342                .sum();
343
344            if options.standardize_x {
345                // GLMNET special formula for no-intercept case:
346                // x_squared_mean = dot_product(sqrt_weights_normalized, x)^2  (squared mean)
347                // x_centered_variance = col_squared_norm - x_squared_mean  (variance-like quantity)
348                // xs = sqrt(x_centered_variance)
349                // column_squared_norms_final = 1.0 + x_squared_mean / x_centered_variance
350                let x_squared_mean: f64 = (0..n)
351                    .map(|i| sqrt_weights_normalized[i] * x_standardized.get(i, j))
352                    .sum::<f64>().powi(2);
353                let x_centered_variance = col_squared_norm_val - x_squared_mean;
354
355                if x_centered_variance > 0.0 {
356                    let col_scale = x_centered_variance.sqrt();
357                    // Scale by col_scale (NOT by sqrt(col_squared_norm_val))
358                    for i in 0..n {
359                        let val = x_standardized.get(i, j) / col_scale;
360                        x_standardized.set(i, j, val);
361                    }
362                    x_scale[j] = col_scale;
363                    column_squared_norms[j] = 1.0 + x_squared_mean / x_centered_variance;  // GLMNET formula
364                } else {
365                    column_squared_norms[j] = 1.0;
366                    x_scale[j] = 1.0;
367                }
368            } else {
369                // No standardization
370                column_squared_norms[j] = col_squared_norm_val;
371                x_scale[j] = 1.0;
372            }
373        }
374    }
375
376    // If intercept column exists, set its scale to 1.0 (not penalized)
377    if options.intercept && p > 0 {
378        x_scale[0] = 1.0;
379        x_mean[0] = 0.0; // Intercept column has no "mean" to subtract
380        column_squared_norms[0] = 1.0;  // Intercept column is not penalized
381    }
382
383    let info = StandardizationInfo {
384        x_mean,
385        x_scale,
386        column_squared_norms,
387        y_mean,
388        y_scale,
389        y_scale_before_sqrt_weights_normalized,
390        intercept: options.intercept,
391        standardized_x: options.standardize_x,
392        standardized_y: options.standardize_y,
393    };
394
395    (x_standardized, y_standardized, info)
396}
397
398/// Unstandardizes coefficients from the standardized space back to original scale.
399///
400/// This reverses the standardization transformation to get coefficients that
401/// can be applied to the original (unscaled) data.
402///
403/// # Arguments
404///
405/// * `coefficients_standardized` - Coefficients in standardized space (length p)
406/// * `info` - Standardization information from [`standardize_xy`]
407///
408/// # Returns
409///
410/// A tuple `(beta0, beta_slopes)` where:
411/// - `beta0` is the intercept on the original scale
412/// - `beta_slopes` are the slope coefficients only (excluding intercept column coefficient)
413///
414/// # Unstandardization Formula
415///
416/// For non-intercept coefficients:
417/// ```text
418/// β_original[j] = (y_scale * β_std[j]) / x_scale[j]
419/// ```
420///
421/// For the intercept:
422/// ```text
423/// β₀ = y_mean - Σⱼ x_mean[j] * β_original[j]
424/// ```
425///
426/// If y was not standardized, `y_scale = 1`.
427/// If X was not standardized, `x_scale[j] = 1`.
428///
429/// # Note
430///
431/// If `intercept=true` in the info, `coefficients_standardized[0]` is assumed to be the intercept
432/// coefficient (which is already 0 in the standardized space since X was centered).
433/// The returned `beta_slopes` will NOT include this zeroed coefficient - only actual
434/// slope coefficients are returned.
435///
436/// # Example
437///
438/// ```
439/// # use linreg_core::regularized::preprocess::{unstandardize_coefficients, StandardizationInfo};
440/// let info = StandardizationInfo {
441///     x_mean: vec![0.0, 5.0],
442///     x_scale: vec![1.0, 2.0],
443///     column_squared_norms: vec![1.0, 1.0],
444///     y_mean: 10.0,
445///     y_scale: Some(3.0),
446///     y_scale_before_sqrt_weights_normalized: Some(3.0),
447///     intercept: true,
448///     standardized_x: true,
449///     standardized_y: false,
450/// };
451///
452/// // Standardized coefficients: [intercept=0, slope1=2.0]
453/// let coefficients_standardized = vec![0.0, 2.0];
454/// let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
455///
456/// // slope_original = (y_scale * slope_std) / x_scale[1]
457/// //                 = (3.0 * 2.0) / 2.0 = 3.0
458/// assert!((beta_slopes[0] - 3.0).abs() < 0.01);
459/// ```
460#[allow(clippy::needless_range_loop)]
461pub fn unstandardize_coefficients(coefficients_standardized: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
462    let p = coefficients_standardized.len();
463    let y_scale = info.y_scale.unwrap_or(1.0);
464
465    // Determine where slope coefficients start in coefficients_standardized
466    let start_idx = if info.intercept { 1 } else { 0 };
467    let n_slopes = p - start_idx;
468
469    // Unstandardize slope coefficients only (exclude intercept column coefficient)
470    // NOTE: X is ALWAYS standardized for the solver, so we always apply the unstandardization formula.
471    // The user's `standardize_x` option doesn't affect the internal computation, only the
472    // interpretation of results.
473    let mut beta_slopes = vec![0.0; n_slopes];
474    for j in start_idx..p {
475        let slope_idx = j - start_idx;
476        // Standard unstandardization: beta_original = (y_scale * coefficients_standardized) / x_scale
477        // This converts from the standardized space back to original data scale
478        beta_slopes[slope_idx] = (y_scale * coefficients_standardized[j]) / info.x_scale[j];
479    }
480
481    // Compute intercept on original scale
482    // beta0 = y_mean - sum(x_mean[j] * beta_slopes[j-1]) for j in 1..p
483    let beta0 = if info.intercept {
484        let mut sum = 0.0;
485        for j in 1..p {
486            sum += info.x_mean[j] * beta_slopes[j - 1];
487        }
488        info.y_mean - sum
489    } else {
490        0.0
491    };
492
493    (beta0, beta_slopes)
494}
495
496/// Computes predictions using unstandardized coefficients.
497///
498/// # Arguments
499///
500/// * `x_new` - New data matrix (n_new × p, with intercept column if applicable)
501/// * `beta0` - Intercept on original scale
502/// * `beta` - Slope coefficients on original scale (does NOT include intercept column coefficient)
503///
504/// # Returns
505///
506/// Predictions for each row in x_new.
507///
508/// # Note
509///
510/// If `x_new` has an intercept column (first column of all ones), `beta` should have
511/// `p - 1` elements corresponding to the non-intercept columns. If `x_new` has no
512/// intercept column, `beta` should have `p` elements.
513///
514/// # Example
515///
516/// ```
517/// # use linreg_core::regularized::preprocess::predict;
518/// # use linreg_core::linalg::Matrix;
519/// // X matrix with intercept: [[1, 2], [1, 3], [1, 4]]
520/// let x_new = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
521/// let beta0 = 1.0;
522/// let beta = vec![2.0];  // One slope coefficient
523///
524/// // predictions[i] = 1.0 + 2.0 * x[i,1]
525/// let preds = predict(&x_new, beta0, &beta);
526/// assert_eq!(preds, vec![5.0, 7.0, 9.0]);
527/// ```
528#[allow(clippy::needless_range_loop)]
529pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
530    let n = x_new.rows;
531    let p = x_new.cols;
532
533    let mut predictions = vec![0.0; n];
534
535    // Determine if there's an intercept column based on beta length
536    // If beta has one fewer element than columns, first column is intercept
537    let has_intercept_col = beta.len() == p - 1;
538    let first_penalized_column_index = if has_intercept_col { 1 } else { 0 };
539
540    for i in 0..n {
541        let mut sum = beta0;
542        for (j, &beta_j) in beta.iter().enumerate() {
543            let col = first_penalized_column_index + j;
544            if col < p {
545                sum += x_new.get(i, col) * beta_j;
546            }
547        }
548        predictions[i] = sum;
549    }
550
551    predictions
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` lies strictly within 1e-10 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    /// A 3×3 design (intercept column of ones plus two predictors) and its response.
    fn sample_design() -> (Matrix, Vec<f64>) {
        let design = Matrix::new(3, 3, vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0]);
        let response = vec![3.0, 5.0, 7.0];
        (design, response)
    }

    /// Checks that column 0 of a 3-row standardized matrix is still all ones.
    fn assert_intercept_column_untouched(m: &Matrix) {
        for row in 0..3 {
            assert_eq!(m.get(row, 0), 1.0);
        }
    }

    #[test]
    fn test_standardize_xy_with_intercept() {
        let (design, response) = sample_design();
        let opts = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            // y is still scaled to unit norm by glmnet convention.
            standardize_y: false,
            weights: None,
        };

        let (x_std, y_std, info) = standardize_xy(&design, &response, &opts);

        assert_intercept_column_untouched(&x_std);

        // Centered y is [-2, 0, 2]. Multiplying by 1/sqrt(3) and rescaling to unit norm
        // gives [-1/sqrt(2), 0, 1/sqrt(2)].
        let r = 1.0 / 2.0_f64.sqrt();
        assert_close(y_std[0], -r);
        assert_close(y_std[1], 0.0);
        assert_close(y_std[2], r);

        // Column means are recorded; the intercept column has none.
        assert_eq!(info.x_mean[0], 0.0);
        assert_close(info.x_mean[1], 4.0);
        assert_close(info.x_mean[2], 6.0);
    }

    #[test]
    fn test_unstandardize_coefficients() {
        let info = StandardizationInfo {
            x_mean: vec![0.0, 4.0, 6.0],
            x_scale: vec![1.0, 2.0, 3.0],
            // Unit norm after standardization.
            column_squared_norms: vec![1.0, 1.0, 1.0],
            y_mean: 5.0,
            y_scale: Some(2.0),
            y_scale_before_sqrt_weights_normalized: None,
            intercept: true,
            standardized_x: true,
            standardized_y: true,
        };

        // Standardized space: [intercept column = 0, beta1 = 1, beta2 = 2].
        let (beta0, slopes) = unstandardize_coefficients(&[0.0, 1.0, 2.0], &info);

        // slope_j = y_scale * b_j / x_scale[j]: (2*1)/2 = 1 and (2*2)/3 = 4/3.
        assert_close(slopes[0], 1.0);
        assert_close(slopes[1], 4.0 / 3.0);

        // beta0 = y_mean - Σ x_mean[j] * slope_j = 5 - (4*1 + 6*4/3) = -7.
        assert_close(beta0, -7.0);

        // Only the slopes come back, not the intercept-column coefficient.
        assert_eq!(slopes.len(), 2);
    }

    #[test]
    fn test_predict() {
        // Two rows: intercept column of ones plus two predictors.
        let design = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0]);

        // Slopes only; the intercept enters through beta0 = 1.
        let preds = predict(&design, 1.0, &[2.0, 3.0]);

        assert_close(preds[0], 14.0); // 1 + 2*2 + 3*3
        assert_close(preds[1], 27.0); // 1 + 2*4 + 3*6
    }

    #[test]
    fn test_weighted_standardize_xy() {
        let (design, response) = sample_design();
        let opts = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            // y is still scaled to unit norm by glmnet convention.
            standardize_y: false,
            // The middle observation counts double.
            weights: Some(vec![1.0, 2.0, 1.0]),
        };

        let (x_std, y_std, info) = standardize_xy(&design, &response, &opts);

        assert_intercept_column_untouched(&x_std);

        // Normalized weights [0.25, 0.5, 0.25] give ym = 0.75 + 2.5 + 1.75 = 5.
        assert_close(info.y_mean, 5.0);

        // sqrt weights [0.5, ~0.707, 0.5] times centered y [-2, 0, 2] give [-1, 0, 1].
        // That vector has norm sqrt(2), so the unit-norm response is
        // [-1/sqrt(2), 0, 1/sqrt(2)].
        let r = 1.0 / 2.0_f64.sqrt();
        assert_close(y_std[0], -r);
        assert_close(y_std[1], 0.0);
        assert_close(y_std[2], r);
    }

    #[test]
    fn test_weighted_standardize_uniform_weights() {
        // Equal weights must reproduce the unweighted result after normalization.
        let design = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0]);
        let response = vec![3.0, 5.0];

        let make_opts = |weights: Option<Vec<f64>>| StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
            weights,
        };

        let (_x_weighted, y_weighted, info_weighted) =
            standardize_xy(&design, &response, &make_opts(Some(vec![1.0, 1.0])));
        let (_x_plain, y_plain, info_plain) =
            standardize_xy(&design, &response, &make_opts(None));

        assert_eq!(info_weighted.y_mean, info_plain.y_mean);
        for idx in 0..2 {
            assert_close(y_weighted[idx], y_plain[idx]);
        }
    }
}