linreg_core/regularized/preprocess.rs
1//! Data preprocessing for regularized regression.
2//!
3//! This module provides standardization utilities that match glmnet's behavior:
4//!
5//! - Predictors are centered and scaled (if enabled)
6//! - The intercept column is not penalized, so it's handled specially
7//! - Coefficients can be unstandardized back to the original scale
8//!
9//! # Standardization Convention
10//!
//! The scaling factor used is `sqrt(sum((x - mean(x))²) / n)`, i.e. the
//! standard deviation under the 1/n convention (matching glmnet), so every
//! scaled column has unit variance.
13
14use crate::linalg::{vec_mean, Matrix};
15
/// Parameters recorded during standardization, needed to map coefficients
/// from the standardized space back to the original data scale.
///
/// Produced by [`standardize_xy`] and consumed by
/// [`unstandardize_coefficients`]. The per-column vectors hold one entry per
/// column of the design matrix (length p). When `intercept` is true, entry 0
/// belongs to the intercept column and keeps the neutral values
/// `x_mean[0] = 0.0` and `x_scale[0] = 1.0`.
#[derive(Debug, Clone)]
pub struct StandardizationInfo {
    /// Mean of each predictor column (length p).
    pub x_mean: Vec<f64>,
    /// Scale factor of each predictor column (length p); `1.0` for columns
    /// that were not scaled.
    pub x_scale: Vec<f64>,
    /// Mean of the response variable (`0.0` when no intercept was requested).
    pub y_mean: f64,
    /// Scale factor of the response, used for lambda path construction;
    /// `None` when standardizing y was not requested.
    pub y_scale: Option<f64>,
    /// Whether an intercept term was included.
    pub intercept: bool,
    /// Whether standardization of X was requested.
    pub standardized_x: bool,
    /// Whether standardization of y was requested.
    pub standardized_y: bool,
}
47
/// Switches controlling what [`standardize_xy`] does to X and y.
///
/// The [`Default`] implementation enables `intercept` and `standardize_x`
/// and disables `standardize_y`.
///
/// # Note
///
/// Turning on `standardize_y` is mainly useful when glmnet's lambda sequence
/// has to be matched exactly. Single-lambda fits typically do not need a
/// standardized response.
#[derive(Debug, Clone)]
pub struct StandardizeOptions {
    /// Whether the model has an intercept. When true, the first column of X
    /// is treated as the intercept column and left untouched, and the mean of
    /// y is recorded.
    pub intercept: bool,
    /// Whether predictor columns are centered and scaled.
    pub standardize_x: bool,
    /// Whether the response is centered and scaled.
    pub standardize_y: bool,
}
70
71impl Default for StandardizeOptions {
72 fn default() -> Self {
73 StandardizeOptions {
74 intercept: true,
75 standardize_x: true,
76 standardize_y: false,
77 }
78 }
79}
80
81/// Standardizes X and y for regularized regression (glmnet-compatible).
82///
83/// This function performs the same standardization as glmnet with
84/// `standardize=TRUE`. The first column of X is assumed to be the intercept
85/// (all ones) and is NOT standardized.
86///
87/// # Arguments
88///
89/// * `x` - Design matrix (n × p). First column should be intercept if `intercept=true`.
90/// * `y` - Response vector (n elements)
91/// * `options` - Standardization options
92///
93/// # Returns
94///
95/// A tuple `(x_std, y_std, info)` where:
96/// - `x_std` is the standardized design matrix
97/// - `y_std` is the (optionally) standardized response
98/// - `info` contains the standardization parameters for unstandardization
99///
100/// # Standardization Details
101///
102/// For the intercept column (first column, if present):
103/// - Not centered (stays as ones)
104/// - Not scaled
105///
106/// For other columns (if `standardize_x=true`):
107/// - Centered: `x_centered = x - mean(x)`
108/// - Scaled: `x_scaled = x_centered / sqrt(sum(x²) / n)`
109///
110/// For y (if `standardize_y=true`):
111/// - Centered: `y_centered = y - mean(y)`
112/// - Scaled: `y_scaled = y_centered / sqrt(sum(y²) / n)`
113pub fn standardize_xy(x: &Matrix, y: &[f64], options: &StandardizeOptions) -> (Matrix, Vec<f64>, StandardizationInfo) {
114 let n = x.rows;
115 let p = x.cols;
116
117 let mut x_std = x.clone();
118 let mut y_std = y.to_vec();
119
120 let mut x_mean = vec![0.0; p];
121 let mut x_scale = vec![1.0; p];
122
123 let y_mean = if options.intercept && !y.is_empty() {
124 vec_mean(y)
125 } else {
126 0.0
127 };
128
129 // Standardize y if requested
130 let y_scale = if options.standardize_y {
131 let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
132 let y_var = y_centered.iter().map(|&yi| yi * yi).sum::<f64>() / n as f64;
133 let y_scale_val = y_var.sqrt();
134 if y_scale_val > 0.0 {
135 for yi in y_std.iter_mut() {
136 *yi = (*yi - y_mean) / y_scale_val;
137 }
138 }
139 Some(y_scale_val)
140 } else {
141 None
142 };
143
144 // Standardize X columns
145 // If intercept is present, first column is NOT standardized
146 let start_col = if options.intercept { 1 } else { 0 };
147
148 for j in start_col..p {
149 // Compute column mean
150 let mut col_mean = 0.0;
151 for i in 0..n {
152 col_mean += x_std.get(i, j);
153 }
154 col_mean /= n as f64;
155 x_mean[j] = col_mean;
156
157 if options.standardize_x {
158 // Center the column
159 for i in 0..n {
160 let val = x_std.get(i, j) - col_mean;
161 x_std.set(i, j, val);
162 }
163
164 // Compute scale: sqrt(sum(x²) / n)
165 let mut col_scale_sq = 0.0;
166 for i in 0..n {
167 let val = x_std.get(i, j);
168 col_scale_sq += val * val;
169 }
170 let col_scale = (col_scale_sq / n as f64).sqrt();
171
172 if col_scale > 0.0 {
173 x_scale[j] = col_scale;
174 // Scale the column
175 for i in 0..n {
176 let val = x_std.get(i, j) / col_scale;
177 x_std.set(i, j, val);
178 }
179 }
180 } else {
181 // Just center, don't scale
182 x_scale[j] = 1.0;
183 }
184 }
185
186 // If intercept column exists, set its scale to 1.0 (not penalized)
187 if options.intercept && p > 0 {
188 x_scale[0] = 1.0;
189 x_mean[0] = 0.0; // Intercept column has no "mean" to subtract
190 }
191
192 let info = StandardizationInfo {
193 x_mean,
194 x_scale,
195 y_mean,
196 y_scale,
197 intercept: options.intercept,
198 standardized_x: options.standardize_x,
199 standardized_y: options.standardize_y,
200 };
201
202 (x_std, y_std, info)
203}
204
205/// Unstandardizes coefficients from the standardized space back to original scale.
206///
207/// This reverses the standardization transformation to get coefficients that
208/// can be applied to the original (unscaled) data.
209///
210/// # Arguments
211///
212/// * `beta_std` - Coefficients in standardized space (length p)
213/// * `info` - Standardization information from [`standardize_xy`]
214///
215/// # Returns
216///
217/// A tuple `(beta0, beta_original)` where:
218/// - `beta0` is the intercept on the original scale
219/// - `beta_original` are the slope coefficients on the original scale
220///
221/// # Unstandardization Formula
222///
223/// For non-intercept coefficients:
224/// ```text
225/// β_original[j] = (y_scale * β_std[j]) / x_scale[j]
226/// ```
227///
228/// For the intercept:
229/// ```text
230/// β₀ = y_mean - Σⱼ x_mean[j] * β_original[j]
231/// ```
232///
233/// If y was not standardized, `y_scale = 1`.
234/// If X was not standardized, `x_scale[j] = 1`.
235///
236/// # Note
237///
238/// If `intercept=true` in the info, `beta_std[0]` is assumed to be the intercept
239/// coefficient (which is already 0 in the standardized space since X was centered).
240pub fn unstandardize_coefficients(beta_std: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
241 let p = beta_std.len();
242 let mut beta_original = vec![0.0; p];
243 let y_scale = info.y_scale.unwrap_or(1.0);
244
245 // Handle intercept: if intercept was used, beta_std[0] is the intercept
246 // In standardized space with centered X, the intercept should be y_mean
247 // But we compute it properly from the formula
248
249 let start_idx = if info.intercept { 1 } else { 0 };
250
251 // Unstandardize non-intercept coefficients
252 for j in start_idx..p {
253 beta_original[j] = (y_scale * beta_std[j]) / info.x_scale[j];
254 }
255
256 // Compute intercept on original scale
257 let beta0 = if info.intercept {
258 let mut sum = 0.0;
259 for j in 1..p {
260 sum += info.x_mean[j] * beta_original[j];
261 }
262 info.y_mean - sum
263 } else {
264 0.0
265 };
266
267 // If intercept was in beta_std, store it separately
268 let intercept_value = if info.intercept {
269 beta0
270 } else {
271 0.0
272 };
273
274 (intercept_value, beta_original)
275}
276
277/// Computes predictions using unstandardized coefficients.
278///
279/// # Arguments
280///
281/// * `x_new` - New data matrix (n_new × p, with intercept column if applicable)
282/// * `beta0` - Intercept on original scale
283/// * `beta` - Slope coefficients on original scale
284///
285/// # Returns
286///
287/// Predictions for each row in x_new.
288pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
289 let n = x_new.rows;
290 let p = x_new.cols;
291
292 let mut predictions = vec![0.0; n];
293
294 // Determine if we have an intercept column (first column is typically all ones)
295 // If beta has one fewer element than columns, assume first column is intercept
296 let has_intercept_col = beta.len() == p - 1;
297 let start_col = if has_intercept_col { 1 } else { 0 };
298
299 for i in 0..n {
300 let mut sum = beta0;
301 for (j, beta_j) in beta.iter().enumerate() {
302 let col = start_col + j;
303 if col < p {
304 sum += x_new.get(i, col) * beta_j;
305 }
306 }
307 predictions[i] = sum;
308 }
309
310 predictions
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by all floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Asserts that two floats agree within [`TOL`].
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_standardize_xy_with_intercept() {
        // 3 observations: intercept column plus two predictors.
        let x = Matrix::new(
            3,
            3,
            vec![
                1.0, 2.0, 3.0, //
                1.0, 4.0, 6.0, //
                1.0, 6.0, 9.0, //
            ],
        );
        let y = vec![3.0, 5.0, 7.0];
        let options = StandardizeOptions {
            intercept: true,
            standardize_x: true,
            standardize_y: false,
        };

        let (x_std, y_std, info) = standardize_xy(&x, &y, &options);

        // The intercept column must come back untouched.
        for row in 0..3 {
            assert_eq!(x_std.get(row, 0), 1.0);
        }

        // With standardize_y = false the response is returned as given.
        for (standardized, original) in y_std.iter().zip(y.iter()) {
            assert_close(*standardized, *original);
        }

        // Recorded means: neutral for the intercept, column means otherwise.
        assert_eq!(info.x_mean[0], 0.0);
        assert_close(info.x_mean[1], 4.0);
        assert_close(info.x_mean[2], 6.0);
    }

    #[test]
    fn test_unstandardize_coefficients() {
        let info = StandardizationInfo {
            x_mean: vec![0.0, 4.0, 6.0],
            x_scale: vec![1.0, 2.0, 3.0],
            y_mean: 5.0,
            y_scale: Some(2.0),
            intercept: true,
            standardized_x: true,
            standardized_y: true,
        };

        // Standardized coefficients: [intercept slot = 0, beta1 = 1, beta2 = 2].
        let beta_std = vec![0.0, 1.0, 2.0];

        let (beta0, beta_orig) = unstandardize_coefficients(&beta_std, &info);

        // beta_orig[1] = (y_scale * beta_std[1]) / x_scale[1] = (2 * 1) / 2 = 1
        assert_close(beta_orig[1], 1.0);
        // beta_orig[2] = (y_scale * beta_std[2]) / x_scale[2] = (2 * 2) / 3 = 4/3
        assert_close(beta_orig[2], 4.0 / 3.0);

        // beta0 = y_mean - sum(x_mean[j] * beta_orig[j])
        //       = 5 - (4 * 1 + 6 * 4/3) = 5 - 4 - 8 = -7
        assert_close(beta0, -7.0);
    }

    #[test]
    fn test_predict() {
        let x = Matrix::new(
            2,
            3,
            vec![
                1.0, 2.0, 3.0, //
                1.0, 4.0, 6.0, //
            ],
        );
        let beta0 = 1.0;
        let beta = vec![2.0, 3.0];

        let preds = predict(&x, beta0, &beta);

        // pred[0] = 1 + 2*2 + 3*3 = 14
        assert_close(preds[0], 14.0);
        // pred[1] = 1 + 2*4 + 3*6 = 27
        assert_close(preds[1], 27.0);
    }
}
405}