linreg_core/regularized/
preprocess.rs

1//! Data preprocessing for regularized regression.
2//!
3//! This module provides standardization utilities that match glmnet's behavior:
4//!
5//! - Predictors are centered and scaled (if enabled)
6//! - The intercept column is not penalized, so it's handled specially
7//! - Coefficients can be unstandardized back to the original scale
8//!
9//! # Standardization Convention
10//!
//! The scaling factor is the population standard deviation of each column,
//! `sqrt(sum((x - mean(x))²) / n)` (the 1/n convention), so every standardized
//! predictor has unit variance (matching glmnet).
13
14use crate::linalg::{vec_mean, Matrix};
15
/// Parameters captured by [`standardize_xy`] that are needed to map
/// coefficients from the standardized space back to the original data scale
/// (see [`unstandardize_coefficients`]).
///
/// Derives `PartialEq` so two standardizations can be compared directly
/// (useful in tests and caching).
#[derive(Clone, Debug, PartialEq)]
pub struct StandardizationInfo {
    /// Per-column means recorded during standardization (length p).
    /// Entry 0 is `0.0` when the first column is the intercept column.
    /// Only read (to recover the intercept) when `intercept` is true.
    pub x_mean: Vec<f64>,
    /// Scale factor for each predictor column (length p). It is `1.0` for the
    /// intercept column, for columns that were not scaled, and for columns
    /// with zero spread.
    pub x_scale: Vec<f64>,
    /// Mean of the response; `0.0` when no intercept is fitted.
    pub y_mean: f64,
    /// Scale factor of the response: `Some` only when y was standardized.
    /// Multiplies the slopes during unstandardization and may also be used
    /// when building a lambda path.
    pub y_scale: Option<f64>,
    /// Whether the first column of X was treated as an unpenalized intercept.
    pub intercept: bool,
    /// Whether standardization of X was requested.
    pub standardized_x: bool,
    /// Whether standardization of y was requested.
    pub standardized_y: bool,
}
47
/// Switches controlling what [`standardize_xy`] does.
///
/// * `intercept` - treat the first column of X as an unpenalized intercept
///   column that is neither centered nor scaled (default: `true`)
/// * `standardize_x` - center/scale the predictor columns (default: `true`)
/// * `standardize_y` - center/scale the response (default: `false`)
///
/// # Note
///
/// Turning on `standardize_y` matters mostly when reproducing glmnet's
/// lambda sequence exactly; a fit at a single lambda usually does not need it.
#[derive(Clone, Debug)]
pub struct StandardizeOptions {
    /// Whether the first column of X is an intercept column
    pub intercept: bool,
    /// Whether the predictor columns are standardized
    pub standardize_x: bool,
    /// Whether the response is standardized
    pub standardize_y: bool,
}

impl Default for StandardizeOptions {
    /// glmnet-style defaults: intercept and predictor standardization on,
    /// response left on its original scale.
    fn default() -> Self {
        Self {
            standardize_y: false,
            standardize_x: true,
            intercept: true,
        }
    }
}
80
81/// Standardizes X and y for regularized regression (glmnet-compatible).
82///
83/// This function performs the same standardization as glmnet with
84/// `standardize=TRUE`. The first column of X is assumed to be the intercept
85/// (all ones) and is NOT standardized.
86///
87/// # Arguments
88///
89/// * `x` - Design matrix (n × p). First column should be intercept if `intercept=true`.
90/// * `y` - Response vector (n elements)
91/// * `options` - Standardization options
92///
93/// # Returns
94///
95/// A tuple `(x_std, y_std, info)` where:
96/// - `x_std` is the standardized design matrix
97/// - `y_std` is the (optionally) standardized response
98/// - `info` contains the standardization parameters for unstandardization
99///
100/// # Standardization Details
101///
102/// For the intercept column (first column, if present):
103/// - Not centered (stays as ones)
104/// - Not scaled
105///
106/// For other columns (if `standardize_x=true`):
107/// - Centered: `x_centered = x - mean(x)`
108/// - Scaled: `x_scaled = x_centered / sqrt(sum(x²) / n)`
109///
110/// For y (if `standardize_y=true`):
111/// - Centered: `y_centered = y - mean(y)`
112/// - Scaled: `y_scaled = y_centered / sqrt(sum(y²) / n)`
113pub fn standardize_xy(x: &Matrix, y: &[f64], options: &StandardizeOptions) -> (Matrix, Vec<f64>, StandardizationInfo) {
114    let n = x.rows;
115    let p = x.cols;
116
117    let mut x_std = x.clone();
118    let mut y_std = y.to_vec();
119
120    let mut x_mean = vec![0.0; p];
121    let mut x_scale = vec![1.0; p];
122
123    let y_mean = if options.intercept && !y.is_empty() {
124        vec_mean(y)
125    } else {
126        0.0
127    };
128
129    // Standardize y if requested
130    let y_scale = if options.standardize_y {
131        let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
132        let y_var = y_centered.iter().map(|&yi| yi * yi).sum::<f64>() / n as f64;
133        let y_scale_val = y_var.sqrt();
134        if y_scale_val > 0.0 {
135            for yi in y_std.iter_mut() {
136                *yi = (*yi - y_mean) / y_scale_val;
137            }
138        }
139        Some(y_scale_val)
140    } else {
141        None
142    };
143
144    // Standardize X columns
145    // If intercept is present, first column is NOT standardized
146    let start_col = if options.intercept { 1 } else { 0 };
147
148    for j in start_col..p {
149        // Compute column mean
150        let mut col_mean = 0.0;
151        for i in 0..n {
152            col_mean += x_std.get(i, j);
153        }
154        col_mean /= n as f64;
155        x_mean[j] = col_mean;
156
157        if options.standardize_x {
158            // Center the column
159            for i in 0..n {
160                let val = x_std.get(i, j) - col_mean;
161                x_std.set(i, j, val);
162            }
163
164            // Compute scale: sqrt(sum(x²) / n)
165            let mut col_scale_sq = 0.0;
166            for i in 0..n {
167                let val = x_std.get(i, j);
168                col_scale_sq += val * val;
169            }
170            let col_scale = (col_scale_sq / n as f64).sqrt();
171
172            if col_scale > 0.0 {
173                x_scale[j] = col_scale;
174                // Scale the column
175                for i in 0..n {
176                    let val = x_std.get(i, j) / col_scale;
177                    x_std.set(i, j, val);
178                }
179            }
180        } else {
181            // Just center, don't scale
182            x_scale[j] = 1.0;
183        }
184    }
185
186    // If intercept column exists, set its scale to 1.0 (not penalized)
187    if options.intercept && p > 0 {
188        x_scale[0] = 1.0;
189        x_mean[0] = 0.0;  // Intercept column has no "mean" to subtract
190    }
191
192    let info = StandardizationInfo {
193        x_mean,
194        x_scale,
195        y_mean,
196        y_scale,
197        intercept: options.intercept,
198        standardized_x: options.standardize_x,
199        standardized_y: options.standardize_y,
200    };
201
202    (x_std, y_std, info)
203}
204
205/// Unstandardizes coefficients from the standardized space back to original scale.
206///
207/// This reverses the standardization transformation to get coefficients that
208/// can be applied to the original (unscaled) data.
209///
210/// # Arguments
211///
212/// * `beta_std` - Coefficients in standardized space (length p)
213/// * `info` - Standardization information from [`standardize_xy`]
214///
215/// # Returns
216///
217/// A tuple `(beta0, beta_original)` where:
218/// - `beta0` is the intercept on the original scale
219/// - `beta_original` are the slope coefficients on the original scale
220///
221/// # Unstandardization Formula
222///
223/// For non-intercept coefficients:
224/// ```text
225/// β_original[j] = (y_scale * β_std[j]) / x_scale[j]
226/// ```
227///
228/// For the intercept:
229/// ```text
230/// β₀ = y_mean - Σⱼ x_mean[j] * β_original[j]
231/// ```
232///
233/// If y was not standardized, `y_scale = 1`.
234/// If X was not standardized, `x_scale[j] = 1`.
235///
236/// # Note
237///
238/// If `intercept=true` in the info, `beta_std[0]` is assumed to be the intercept
239/// coefficient (which is already 0 in the standardized space since X was centered).
240pub fn unstandardize_coefficients(beta_std: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
241    let p = beta_std.len();
242    let mut beta_original = vec![0.0; p];
243    let y_scale = info.y_scale.unwrap_or(1.0);
244
245    // Handle intercept: if intercept was used, beta_std[0] is the intercept
246    // In standardized space with centered X, the intercept should be y_mean
247    // But we compute it properly from the formula
248
249    let start_idx = if info.intercept { 1 } else { 0 };
250
251    // Unstandardize non-intercept coefficients
252    for j in start_idx..p {
253        beta_original[j] = (y_scale * beta_std[j]) / info.x_scale[j];
254    }
255
256    // Compute intercept on original scale
257    let beta0 = if info.intercept {
258        let mut sum = 0.0;
259        for j in 1..p {
260            sum += info.x_mean[j] * beta_original[j];
261        }
262        info.y_mean - sum
263    } else {
264        0.0
265    };
266
267    // If intercept was in beta_std, store it separately
268    let intercept_value = if info.intercept {
269        beta0
270    } else {
271        0.0
272    };
273
274    (intercept_value, beta_original)
275}
276
277/// Computes predictions using unstandardized coefficients.
278///
279/// # Arguments
280///
281/// * `x_new` - New data matrix (n_new × p, with intercept column if applicable)
282/// * `beta0` - Intercept on original scale
283/// * `beta` - Slope coefficients on original scale
284///
285/// # Returns
286///
287/// Predictions for each row in x_new.
288pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
289    let n = x_new.rows;
290    let p = x_new.cols;
291
292    let mut predictions = vec![0.0; n];
293
294    // Determine if we have an intercept column (first column is typically all ones)
295    // If beta has one fewer element than columns, assume first column is intercept
296    let has_intercept_col = beta.len() == p - 1;
297    let start_col = if has_intercept_col { 1 } else { 0 };
298
299    for i in 0..n {
300        let mut sum = beta0;
301        for (j, beta_j) in beta.iter().enumerate() {
302            let col = start_col + j;
303            if col < p {
304                sum += x_new.get(i, col) * beta_j;
305            }
306        }
307        predictions[i] = sum;
308    }
309
310    predictions
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// True when `actual` and `expected` agree to within [`TOL`].
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_standardize_xy_with_intercept() {
        // Column 0 is the intercept; columns 1 and 2 are ordinary predictors.
        let design = Matrix::new(
            3,
            3,
            vec![
                1.0, 2.0, 3.0,
                1.0, 4.0, 6.0,
                1.0, 6.0, 9.0,
            ],
        );
        let response = vec![3.0, 5.0, 7.0];
        let opts = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
        };

        let (x_std, y_std, info) = standardize_xy(&design, &response, &opts);

        // The intercept column passes through untouched.
        assert!((0..3).all(|row| x_std.get(row, 0) == 1.0));

        // standardize_y is off, so the response comes back unchanged.
        assert_eq!(y_std.len(), response.len());
        assert!(y_std
            .iter()
            .zip(&response)
            .all(|(&got, &want)| approx(got, want)));

        // Recorded means: 0 for the intercept slot, column means elsewhere.
        assert_eq!(info.x_mean[0], 0.0);
        assert!(approx(info.x_mean[1], 4.0));
        assert!(approx(info.x_mean[2], 6.0));
    }

    #[test]
    fn test_unstandardize_coefficients() {
        let info = StandardizationInfo {
            x_mean: vec![0.0, 4.0, 6.0],
            x_scale: vec![1.0, 2.0, 3.0],
            y_mean: 5.0,
            y_scale: Some(2.0),
            intercept: true,
            standardized_x: true,
            standardized_y: true,
        };

        // Standardized-space coefficients: [intercept = 0, 1, 2].
        let (beta0, beta_orig) = unstandardize_coefficients(&[0.0, 1.0, 2.0], &info);

        // beta_orig[j] = y_scale * beta_std[j] / x_scale[j]
        //   j = 1: 2 * 1 / 2 = 1
        //   j = 2: 2 * 2 / 3 = 4/3
        assert!(approx(beta_orig[1], 1.0));
        assert!(approx(beta_orig[2], 4.0 / 3.0));

        // beta0 = y_mean - Σ x_mean[j] * beta_orig[j] = 5 - (4 + 8) = -7
        assert!(approx(beta0, -7.0));
    }

    #[test]
    fn test_predict() {
        let design = Matrix::new(
            2,
            3,
            vec![
                1.0, 2.0, 3.0,
                1.0, 4.0, 6.0,
            ],
        );

        // Two slopes for three columns: column 0 is treated as the intercept.
        let preds = predict(&design, 1.0, &[2.0, 3.0]);

        // Row 0: 1 + 2*2 + 3*3 = 14; row 1: 1 + 2*4 + 3*6 = 27
        let expected = [14.0, 27.0];
        for (got, want) in preds.iter().zip(expected) {
            assert!(approx(*got, want));
        }
    }
}