linreg_core/regularized/preprocess.rs
1//! Data preprocessing for regularized regression.
2//!
3//! This module provides standardization utilities that match glmnet output behavior:
4//!
5//! - Predictors are centered and scaled (if enabled)
6//! - The intercept column is not penalized, so it's handled specially
7//! - Coefficients can be unstandardized back to the original scale
8//! - Observation weights are supported for weighted regression
9//!
10//! # Standardization Convention
11//!
//! The scaling factor used is `sqrt(sum((x - mean(x))²) / n)` (the standard deviation
//! under the 1/n convention), which gives unit variance and matches the glmnet paper.
14//!
15//! # Weighted Standardization
16//!
17//! When weights are provided, they are first normalized to sum to 1:
18//! `weights_normalized = w / sum(w)`. Then weighted means and variances are computed.
19
20use crate::linalg::Matrix;
21
/// Information stored during standardization, used to unstandardize coefficients.
///
/// This struct captures all the information needed to transform coefficients
/// from the standardized space back to the original data scale. It is produced by
/// [`standardize_xy`] and consumed by [`unstandardize_coefficients`].
///
/// # Fields
///
/// * `x_mean` - Mean of each predictor column (length p)
/// * `x_scale` - Scale factor for each predictor column (length p)
/// * `column_squared_norms` - Squared norm of each transformed predictor column (length p)
/// * `y_mean` - Mean of response variable
/// * `y_scale` - Scale factor for response (optional, used for lambda path)
/// * `y_scale_before_sqrt_weights_normalized` - Response scale before the sqrt-weight
///   transformation (optional)
/// * `intercept` - Whether an intercept term was included
/// * `standardized_x` - Whether standardization of X was requested
/// * `standardized_y` - Whether standardization of y was requested
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::preprocess::StandardizationInfo;
/// let info = StandardizationInfo {
///     x_mean: vec![0.0, 5.0],
///     x_scale: vec![1.0, 2.0],
///     column_squared_norms: vec![1.0, 1.0],
///     y_mean: 10.0,
///     y_scale: Some(3.0),
///     y_scale_before_sqrt_weights_normalized: Some(3.0),
///     intercept: true,
///     standardized_x: true,
///     standardized_y: false,
/// };
///
/// assert_eq!(info.x_mean.len(), 2);
/// assert!(info.intercept);
/// ```
#[derive(Clone, Debug)]
pub struct StandardizationInfo {
    /// Mean of each predictor column.
    /// `standardize_xy` stores the weighted column mean when an intercept is used and
    /// 0.0 otherwise; the entry for the intercept column itself is always 0.0.
    pub x_mean: Vec<f64>,
    /// Scale factor for each predictor column.
    /// Entries are 1.0 for the intercept column and for columns that were not scaled.
    pub x_scale: Vec<f64>,
    /// Squared norm of each predictor column after standardization.
    /// This is used in the coordinate descent update denominator.
    /// - With intercept and standardize: column_squared_norms\[j\] = 1.0 (unit norm after centering)
    /// - Without intercept and standardize: column_squared_norms\[j\] = 1.0 + x_squared_mean/x_centered_variance (glmnet formula)
    /// - Without standardize: column_squared_norms\[j\] = ||x_j||^2 (actual squared norm)
    pub column_squared_norms: Vec<f64>,
    /// Mean of response variable (weighted mean when weights are supplied; 0.0 without intercept)
    pub y_mean: f64,
    /// Scale factor for response (for lambda path construction)
    /// This is the norm AFTER sqrt_weights_normalized transformation and centering: sqrt(sum((sqrt_weights_normalized*(y-ym))^2))
    /// `unstandardize_coefficients` treats `None` as a scale of 1.0.
    pub y_scale: Option<f64>,
    /// Scale factor for response BEFORE sqrt_weights_normalized transformation: sqrt(sum((y-ym)^2))
    /// This is used for lambda scaling between original and standardized data.
    /// NOTE: in the no-intercept path `standardize_xy` stores the same value as `y_scale` here.
    pub y_scale_before_sqrt_weights_normalized: Option<f64>,
    /// Whether an intercept was included
    pub intercept: bool,
    /// Whether X standardization was requested (copied from `StandardizeOptions::standardize_x`)
    pub standardized_x: bool,
    /// Whether y standardization was requested (copied from `StandardizeOptions::standardize_y`).
    /// `standardize_xy` scales y to unit norm regardless of this flag.
    pub standardized_y: bool,
}
83
/// Options for standardization.
///
/// # Fields
///
/// * `intercept` - Whether to include/center an intercept (default: true)
/// * `standardize_x` - Whether to standardize predictors (default: true)
/// * `standardize_y` - Whether to standardize response (default: false)
/// * `weights` - Optional observation weights (default: None)
///   If provided, weights are normalized to sum to 1 before use.
///
/// # Note
///
/// Setting `standardize_y` to `true` is mainly useful when you want to match
/// glmnet's lambda sequence exactly. For single-lambda fits, you typically
/// don't need to standardize y.
///
/// Within this module, `standardize_xy` only records `standardize_y` in the returned
/// `StandardizationInfo`; the response is scaled to unit norm whatever its value.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::preprocess::StandardizeOptions;
/// let opts = StandardizeOptions {
///     intercept: true,
///     standardize_x: true,
///     standardize_y: false,
///     weights: None,
/// };
///
/// assert!(opts.intercept);
/// assert!(opts.standardize_x);
/// ```
#[derive(Clone, Debug)]
pub struct StandardizeOptions {
    /// Whether to include an intercept (and center X).
    /// When `true`, `standardize_xy` treats the first column of X as the intercept column.
    pub intercept: bool,
    /// Whether to standardize predictor columns
    pub standardize_x: bool,
    /// Whether to standardize the response variable
    pub standardize_y: bool,
    /// Optional observation weights (normalized to sum to 1).
    /// `standardize_xy` expects one weight per row of X, panics on a negative weight and
    /// returns zero-filled placeholders when the length does not match.
    pub weights: Option<Vec<f64>>,
}
125
126impl Default for StandardizeOptions {
127 fn default() -> Self {
128 StandardizeOptions {
129 intercept: true,
130 standardize_x: true,
131 standardize_y: false,
132 weights: None,
133 }
134 }
135}
136
137/// Standardizes X and y for regularized regression (glmnet-compatible).
138///
139/// This function performs the same standardization as glmnet with
140/// `standardize=TRUE`. The first column of X is assumed to be the intercept
141/// (all ones) and is NOT standardized.
142///
143/// # Arguments
144///
145/// * `x` - Design matrix (n × p). First column should be intercept if `intercept=true`.
146/// * `y` - Response vector (n elements)
147/// * `options` - Standardization options (including optional observation weights)
148///
149/// # Returns
150///
151/// A tuple `(x_standardized, y_standardized, info)` where:
152/// - `x_standardized` is the standardized design matrix
153/// - `y_standardized` is the (optionally) standardized response
154/// - `info` contains the standardization parameters for unstandardization
155///
156/// # Standardization Details
157///
158/// ## Unweighted case:
159/// For the intercept column (first column, if present):
160/// - Not centered (stays as ones)
161/// - Not scaled
162///
163/// For other columns (if `standardize_x=true`):
164/// - Centered: `x_centered = x - mean(x)`
165/// - Scaled: `x_scaled = x_centered / sqrt(sum(x²))`
166///
167/// For y (if `standardize_y=true`):
168/// - Centered: `y_centered = y - mean(y)`
169/// - Scaled: `y_scaled = y_centered / sqrt(sum(y²))`
170///
171/// ## Weighted case:
172/// Weights are first normalized: `weights_normalized = w / sum(w)`, then `sqrt_weights_normalized = sqrt(weights_normalized)`
173/// - Weighted mean: `ym = sum(w * y) / sum(w) = sum(weights_normalized * y)`
174/// - Weighted variance: `sum(w * (y - ym)^2) / sum(w)`
175/// - Data transformed by `sqrt_weights_normalized`: `y_new = sqrt_weights_normalized * (y - ym)`, then scaled
176#[allow(clippy::needless_range_loop)]
177pub fn standardize_xy(
178 x: &Matrix,
179 y: &[f64],
180 options: &StandardizeOptions,
181) -> (Matrix, Vec<f64>, StandardizationInfo) {
182 let n = x.rows;
183 let p = x.cols;
184
185 // Validate weights if provided
186 if let Some(ref w) = options.weights {
187 if w.len() != n {
188 return (
189 Matrix::new(n, p, vec![0.0; n * p]),
190 vec![0.0; n],
191 StandardizationInfo {
192 x_mean: vec![0.0; p],
193 x_scale: vec![1.0; p],
194 column_squared_norms: vec![0.0; p],
195 y_mean: 0.0,
196 y_scale: None,
197 y_scale_before_sqrt_weights_normalized: None,
198 intercept: options.intercept,
199 standardized_x: options.standardize_x,
200 standardized_y: options.standardize_y,
201 },
202 );
203 }
204 if w.iter().any(|&wi| wi < 0.0) {
205 panic!("Weights must be non-negative");
206 }
207 }
208
209 // Prepare normalized weights and sqrt(weights)
210 // w = w / sum(w) then sqrt_weights_normalized = sqrt(w)
211 let (weights_normalized, sqrt_weights_normalized): (Vec<f64>, Vec<f64>) = if let Some(ref w) = options.weights {
212 let w_sum: f64 = w.iter().sum();
213 if w_sum > 0.0 {
214 let weights_normalized_vec: Vec<f64> = w.iter().map(|&wi| wi / w_sum).collect();
215 let sqrt_weights_normalized_vec: Vec<f64> = weights_normalized_vec.iter().map(|&wi| wi.sqrt()).collect();
216 (weights_normalized_vec, sqrt_weights_normalized_vec)
217 } else {
218 (vec![0.0; n], vec![0.0; n])
219 }
220 } else {
221 // No weights: use uniform weights
222 let w_uniform = vec![1.0 / n as f64; n];
223 let sqrt_weights_normalized_uniform = vec![1.0 / (n as f64).sqrt(); n];
224 (w_uniform, sqrt_weights_normalized_uniform)
225 };
226
227 let mut x_standardized = x.clone();
228 let mut y_standardized = y.to_vec();
229
230 let mut x_mean = vec![0.0; p];
231 let mut x_scale = vec![1.0; p];
232 let mut column_squared_norms = vec![0.0; p]; // Column squared norms for coordinate descent
233
234 let y_mean = if options.intercept && !y.is_empty() {
235 // Weighted mean: ym = sum(w * y)
236 weights_normalized.iter().zip(y.iter()).map(|(&w, &yi)| w * yi).sum()
237 } else {
238 0.0
239 };
240
241 // GLMNET: y is ALWAYS scaled to unit norm!
242 // This is critical for correct lambda_max computation
243 let (y_scale, y_scale_before_sqrt_weights_normalized) = if options.intercept {
244 // WITH INTERCEPT: Center y, then scale to unit norm
245 // First compute y_scale_before_sqrt_weights_normalized (centered but not sqrt_weights_normalized-transformed)
246 let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
247 let y_ss_before_sqrt_weights_normalized: f64 = y_centered.iter().map(|&yi| yi * yi).sum();
248 let y_scale_before_sqrt_weights_normalized_val = y_ss_before_sqrt_weights_normalized.sqrt();
249
250 // Center y: y_new = sqrt_weights_normalized * (y - ym)
251 for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
252 *yi = sqrt_weight * (*yi - y_mean);
253 }
254
255 // Scale to unit norm (GLMNET always does this!)
256 let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
257 let y_scale_val = y_ss.sqrt();
258 if y_scale_val > 0.0 {
259 for yi in y_standardized.iter_mut() {
260 *yi /= y_scale_val;
261 }
262 }
263 (Some(y_scale_val), Some(y_scale_before_sqrt_weights_normalized_val))
264 } else {
265 // WITHOUT INTERCEPT: Don't center y, but DO scale to unit norm (GLMNET output behavior!)
266 // y_new = sqrt_weights_normalized * y, then y = y / ||y||
267 for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
268 *yi *= sqrt_weight;
269 }
270 let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
271 let y_scale_val = y_ss.sqrt();
272 if y_scale_val > 0.0 {
273 for yi in y_standardized.iter_mut() {
274 *yi /= y_scale_val;
275 }
276 }
277 (Some(y_scale_val), Some(y_scale_val)) // y_scale_before_sqrt_weights_normalized = y_scale when no centering
278 };
279
280 // Standardize X columns
281 // If intercept is present, first column is NOT standardized (it's the intercept column)
282 let first_penalized_column_index = if options.intercept { 1 } else { 0 };
283
284 if options.intercept {
285 // WITH INTERCEPT (intercept column not standardized)
286 for j in first_penalized_column_index..p {
287 // Compute weighted column mean and center
288 let col_mean: f64 = (0..n)
289 .map(|i| x_standardized.get(i, j) * weights_normalized[i])
290 .sum();
291 x_mean[j] = col_mean;
292
293 // Center the column and apply sqrt_weights_normalized transformation
294 // x_new = sqrt_weights_normalized * (x - xm)
295 for i in 0..n {
296 let val = sqrt_weights_normalized[i] * (x_standardized.get(i, j) - col_mean);
297 x_standardized.set(i, j, val);
298 }
299
300 // Compute squared norm
301 let col_squared_norm_val: f64 = (0..n)
302 .map(|i| {
303 let val = x_standardized.get(i, j);
304 val * val
305 })
306 .sum();
307
308 if options.standardize_x {
309 // Scale to unit norm
310 let col_scale = col_squared_norm_val.sqrt();
311 if col_scale > 0.0 {
312 for i in 0..n {
313 let val = x_standardized.get(i, j) / col_scale;
314 x_standardized.set(i, j, val);
315 }
316 x_scale[j] = col_scale;
317 column_squared_norms[j] = 1.0; // Unit norm
318 }
319 } else {
320 // No standardization - column_squared_norms stays as the actual squared norm
321 column_squared_norms[j] = col_squared_norm_val;
322 x_scale[j] = 1.0;
323 }
324 }
325 } else {
326 // WITHOUT INTERCEPT (no centering)
327 for j in first_penalized_column_index..p {
328 x_mean[j] = 0.0; // No centering
329
330 // Apply sqrt_weights_normalized transformation
331 for i in 0..n {
332 let val = sqrt_weights_normalized[i] * x_standardized.get(i, j);
333 x_standardized.set(i, j, val);
334 }
335
336 // Compute squared norm after sqrt_weights_normalized transformation
337 let col_squared_norm_val: f64 = (0..n)
338 .map(|i| {
339 let val = x_standardized.get(i, j);
340 val * val
341 })
342 .sum();
343
344 if options.standardize_x {
345 // GLMNET special formula for no-intercept case:
346 // x_squared_mean = dot_product(sqrt_weights_normalized, x)^2 (squared mean)
347 // x_centered_variance = col_squared_norm - x_squared_mean (variance-like quantity)
348 // xs = sqrt(x_centered_variance)
349 // column_squared_norms_final = 1.0 + x_squared_mean / x_centered_variance
350 let x_squared_mean: f64 = (0..n)
351 .map(|i| sqrt_weights_normalized[i] * x_standardized.get(i, j))
352 .sum::<f64>().powi(2);
353 let x_centered_variance = col_squared_norm_val - x_squared_mean;
354
355 if x_centered_variance > 0.0 {
356 let col_scale = x_centered_variance.sqrt();
357 // Scale by col_scale (NOT by sqrt(col_squared_norm_val))
358 for i in 0..n {
359 let val = x_standardized.get(i, j) / col_scale;
360 x_standardized.set(i, j, val);
361 }
362 x_scale[j] = col_scale;
363 column_squared_norms[j] = 1.0 + x_squared_mean / x_centered_variance; // GLMNET formula
364 } else {
365 column_squared_norms[j] = 1.0;
366 x_scale[j] = 1.0;
367 }
368 } else {
369 // No standardization
370 column_squared_norms[j] = col_squared_norm_val;
371 x_scale[j] = 1.0;
372 }
373 }
374 }
375
376 // If intercept column exists, set its scale to 1.0 (not penalized)
377 if options.intercept && p > 0 {
378 x_scale[0] = 1.0;
379 x_mean[0] = 0.0; // Intercept column has no "mean" to subtract
380 column_squared_norms[0] = 1.0; // Intercept column is not penalized
381 }
382
383 let info = StandardizationInfo {
384 x_mean,
385 x_scale,
386 column_squared_norms,
387 y_mean,
388 y_scale,
389 y_scale_before_sqrt_weights_normalized,
390 intercept: options.intercept,
391 standardized_x: options.standardize_x,
392 standardized_y: options.standardize_y,
393 };
394
395 (x_standardized, y_standardized, info)
396}
397
398/// Unstandardizes coefficients from the standardized space back to original scale.
399///
400/// This reverses the standardization transformation to get coefficients that
401/// can be applied to the original (unscaled) data.
402///
403/// # Arguments
404///
405/// * `coefficients_standardized` - Coefficients in standardized space (length p)
406/// * `info` - Standardization information from [`standardize_xy`]
407///
408/// # Returns
409///
410/// A tuple `(beta0, beta_slopes)` where:
411/// - `beta0` is the intercept on the original scale
412/// - `beta_slopes` are the slope coefficients only (excluding intercept column coefficient)
413///
414/// # Unstandardization Formula
415///
416/// For non-intercept coefficients:
417/// ```text
418/// β_original\[j\] = (y_scale * β_std\[j\]) / x_scale\[j\]
419/// ```
420///
421/// For the intercept:
422/// ```text
423/// β₀ = y_mean - Σⱼ x_mean\[j\] * β_original\[j\]
424/// ```
425///
426/// If y was not standardized, `y_scale = 1`.
427/// If X was not standardized, `x_scale\[j\] = 1`.
428///
429/// # Note
430///
431/// If `intercept=true` in the info, `coefficients_standardized[0]` is assumed to be the intercept
432/// coefficient (which is already 0 in the standardized space since X was centered).
433/// The returned `beta_slopes` will NOT include this zeroed coefficient - only actual
434/// slope coefficients are returned.
435///
436/// # Example
437///
438/// ```
439/// # use linreg_core::regularized::preprocess::{unstandardize_coefficients, StandardizationInfo};
440/// let info = StandardizationInfo {
441/// x_mean: vec![0.0, 5.0],
442/// x_scale: vec![1.0, 2.0],
443/// column_squared_norms: vec![1.0, 1.0],
444/// y_mean: 10.0,
445/// y_scale: Some(3.0),
446/// y_scale_before_sqrt_weights_normalized: Some(3.0),
447/// intercept: true,
448/// standardized_x: true,
449/// standardized_y: false,
450/// };
451///
452/// // Standardized coefficients: [intercept=0, slope1=2.0]
453/// let coefficients_standardized = vec![0.0, 2.0];
454/// let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
455///
456/// // slope_original = (y_scale * slope_std) / x_scale[1]
457/// // = (3.0 * 2.0) / 2.0 = 3.0
458/// assert!((beta_slopes[0] - 3.0).abs() < 0.01);
459/// ```
460#[allow(clippy::needless_range_loop)]
461pub fn unstandardize_coefficients(coefficients_standardized: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
462 let p = coefficients_standardized.len();
463 let y_scale = info.y_scale.unwrap_or(1.0);
464
465 // Determine where slope coefficients start in coefficients_standardized
466 let start_idx = if info.intercept { 1 } else { 0 };
467 let n_slopes = p - start_idx;
468
469 // Unstandardize slope coefficients only (exclude intercept column coefficient)
470 // NOTE: X is ALWAYS standardized for the solver, so we always apply the unstandardization formula.
471 // The user's `standardize_x` option doesn't affect the internal computation, only the
472 // interpretation of results.
473 let mut beta_slopes = vec![0.0; n_slopes];
474 for j in start_idx..p {
475 let slope_idx = j - start_idx;
476 // Standard unstandardization: beta_original = (y_scale * coefficients_standardized) / x_scale
477 // This converts from the standardized space back to original data scale
478 beta_slopes[slope_idx] = (y_scale * coefficients_standardized[j]) / info.x_scale[j];
479 }
480
481 // Compute intercept on original scale
482 // beta0 = y_mean - sum(x_mean[j] * beta_slopes[j-1]) for j in 1..p
483 let beta0 = if info.intercept {
484 let mut sum = 0.0;
485 for j in 1..p {
486 sum += info.x_mean[j] * beta_slopes[j - 1];
487 }
488 info.y_mean - sum
489 } else {
490 0.0
491 };
492
493 (beta0, beta_slopes)
494}
495
496/// Computes predictions using unstandardized coefficients.
497///
498/// # Arguments
499///
500/// * `x_new` - New data matrix (n_new × p, with intercept column if applicable)
501/// * `beta0` - Intercept on original scale
502/// * `beta` - Slope coefficients on original scale (does NOT include intercept column coefficient)
503///
504/// # Returns
505///
506/// Predictions for each row in x_new.
507///
508/// # Note
509///
510/// If `x_new` has an intercept column (first column of all ones), `beta` should have
511/// `p - 1` elements corresponding to the non-intercept columns. If `x_new` has no
512/// intercept column, `beta` should have `p` elements.
513///
514/// # Example
515///
516/// ```
517/// # use linreg_core::regularized::preprocess::predict;
518/// # use linreg_core::linalg::Matrix;
519/// // X matrix with intercept: [[1, 2], [1, 3], [1, 4]]
520/// let x_new = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
521/// let beta0 = 1.0;
522/// let beta = vec![2.0]; // One slope coefficient
523///
524/// // predictions[i] = 1.0 + 2.0 * x[i,1]
525/// let preds = predict(&x_new, beta0, &beta);
526/// assert_eq!(preds, vec![5.0, 7.0, 9.0]);
527/// ```
528#[allow(clippy::needless_range_loop)]
529pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
530 let n = x_new.rows;
531 let p = x_new.cols;
532
533 let mut predictions = vec![0.0; n];
534
535 // Determine if there's an intercept column based on beta length
536 // If beta has one fewer element than columns, first column is intercept
537 let has_intercept_col = beta.len() == p - 1;
538 let first_penalized_column_index = if has_intercept_col { 1 } else { 0 };
539
540 for i in 0..n {
541 let mut sum = beta0;
542 for (j, &beta_j) in beta.iter().enumerate() {
543 let col = first_penalized_column_index + j;
544 if col < p {
545 sum += x_new.get(i, col) * beta_j;
546 }
547 }
548 predictions[i] = sum;
549 }
550
551 predictions
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every approximate comparison in this module.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Options with `standardize_y` off, as used by every standardization test here.
    fn options(intercept: bool, standardize_x: bool, weights: Option<Vec<f64>>) -> StandardizeOptions {
        StandardizeOptions {
            intercept,
            standardize_x,
            standardize_y: false,
            weights,
        }
    }

    /// Three observations: intercept column plus two predictors.
    fn design_3x3() -> Matrix {
        Matrix::new(3, 3, vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0])
    }

    /// Two observations: intercept column plus two predictors.
    fn design_2x3() -> Matrix {
        Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0])
    }

    /// Hand-built standardization info shared by the unstandardization tests.
    fn fixed_info(y_mean: f64, y_scale: Option<f64>, intercept: bool, standardized_y: bool) -> StandardizationInfo {
        StandardizationInfo {
            x_mean: vec![0.0, 4.0, 6.0],
            x_scale: vec![1.0, 2.0, 3.0],
            column_squared_norms: vec![1.0, 1.0, 1.0],
            y_mean,
            y_scale,
            y_scale_before_sqrt_weights_normalized: None,
            intercept,
            standardized_x: true,
            standardized_y,
        }
    }

    #[test]
    fn test_standardize_xy_with_intercept() {
        let response = vec![3.0, 5.0, 7.0];
        // standardize_y is false, but y is still scaled to unit norm by glmnet convention.
        let (x_out, y_out, info) = standardize_xy(&design_3x3(), &response, &options(true, true, None));

        // The intercept column must come back untouched.
        for row in 0..3 {
            assert_eq!(x_out.get(row, 0), 1.0);
        }

        // Centered y is [-2, 0, 2]; the uniform sqrt-weights cancel in the unit-norm
        // scaling, leaving [-1/sqrt(2), 0, 1/sqrt(2)].
        let inv_sqrt2 = 1.0 / (2.0_f64).sqrt();
        assert_close(y_out[0], -inv_sqrt2);
        assert_close(y_out[1], 0.0);
        assert_close(y_out[2], inv_sqrt2);

        // Column means are recorded; the intercept column reports 0.
        assert_eq!(info.x_mean[0], 0.0);
        assert_close(info.x_mean[1], 4.0);
        assert_close(info.x_mean[2], 6.0);
    }

    #[test]
    fn test_unstandardize_coefficients() {
        let info = fixed_info(5.0, Some(2.0), true, true);

        // Standardized coefficients: [intercept=0, beta1=1, beta2=2]
        let (beta0, beta_slopes) = unstandardize_coefficients(&[0.0, 1.0, 2.0], &info);

        // slope = (y_scale * coefficient) / x_scale: (2 * 1) / 2 = 1 and (2 * 2) / 3 = 4/3
        assert_close(beta_slopes[0], 1.0);
        assert_close(beta_slopes[1], 4.0 / 3.0);

        // beta0 = y_mean - sum(x_mean * slope) = 5 - (4 * 1 + 6 * 4/3) = -7
        assert_close(beta0, -7.0);

        // Only slopes are returned, not the intercept column coefficient.
        assert_eq!(beta_slopes.len(), 2);
    }

    #[test]
    fn test_predict() {
        // Intercept column plus two predictors; beta holds the two slopes only.
        let preds = predict(&design_2x3(), 1.0, &[2.0, 3.0]);

        // 1 + 2*2 + 3*3 = 14
        assert_close(preds[0], 14.0);
        // 1 + 2*4 + 3*6 = 27
        assert_close(preds[1], 27.0);
    }

    #[test]
    fn test_weighted_standardize_xy() {
        let response = vec![3.0, 5.0, 7.0];
        // The middle observation carries twice the weight of the others.
        let opts = options(true, true, Some(vec![1.0, 2.0, 1.0]));

        let (x_out, y_out, info) = standardize_xy(&design_3x3(), &response, &opts);

        // The intercept column must come back untouched.
        for row in 0..3 {
            assert_eq!(x_out.get(row, 0), 1.0);
        }

        // Normalized weights are [0.25, 0.5, 0.25], so the weighted mean is
        // 0.25*3 + 0.5*5 + 0.25*7 = 5.
        assert_close(info.y_mean, 5.0);

        // Centered y is [-2, 0, 2]; multiplying by sqrt-weights [0.5, ~0.707, 0.5] gives
        // [-1, 0, 1], whose norm is sqrt(2), hence [-1/sqrt(2), 0, 1/sqrt(2)].
        let expected_first = -1.0 / (2.0_f64).sqrt();
        assert_close(y_out[0], expected_first);
        assert_close(y_out[1], 0.0);
        assert_close(y_out[2], -expected_first);
    }

    #[test]
    fn test_weighted_standardize_uniform_weights() {
        let design = design_2x3();
        let response = vec![3.0, 5.0];

        // Equal weights normalize to the same uniform weights used when none are given.
        let weighted = options(true, true, Some(vec![1.0, 1.0]));
        let unweighted = options(true, true, None);

        let (_x_w, y_w, info_w) = standardize_xy(&design, &response, &weighted);
        let (_x_u, y_u, info_u) = standardize_xy(&design, &response, &unweighted);

        assert_eq!(info_w.y_mean, info_u.y_mean);
        for idx in 0..2 {
            assert_close(y_w[idx], y_u[idx]);
        }
    }

    #[test]
    fn test_standardize_xy_weights_dimension_mismatch() {
        let response = vec![3.0, 5.0];
        // Three weights for two observations triggers the early-return path.
        let opts = options(true, true, Some(vec![1.0, 1.0, 1.0]));

        let (x_out, y_out, info) = standardize_xy(&design_2x3(), &response, &opts);

        // Zero-filled outputs of the right shape, with the option flags echoed back.
        assert_eq!(x_out.rows, 2);
        assert_eq!(x_out.cols, 3);
        assert_eq!(y_out, vec![0.0, 0.0]);
        assert!(!info.standardized_y);
        assert!(info.intercept);
        assert!(info.standardized_x);
    }

    #[test]
    #[should_panic(expected = "Weights must be non-negative")]
    fn test_standardize_xy_negative_weights_panics() {
        let response = vec![3.0, 5.0];
        let opts = options(true, true, Some(vec![1.0, -0.5]));

        let _ = standardize_xy(&design_2x3(), &response, &opts);
    }

    #[test]
    fn test_standardize_xy_zero_sum_weights() {
        let response = vec![3.0, 5.0];
        // All-zero weights sum to 0, so every normalized weight is 0.
        let opts = options(true, true, Some(vec![0.0, 0.0]));

        let (_x_out, y_out, info) = standardize_xy(&design_2x3(), &response, &opts);

        assert_eq!(info.y_mean, 0.0);
        assert_eq!(y_out, vec![0.0, 0.0]);
    }

    #[test]
    fn test_standardize_xy_without_intercept() {
        // No intercept column in this design.
        let design = Matrix::new(2, 3, vec![2.0, 3.0, 4.0, 6.0, 8.0, 9.0]);
        let response = vec![3.0, 5.0];

        let (_x_out, y_out, info) = standardize_xy(&design, &response, &options(false, true, None));

        // y is not centered without an intercept...
        assert_eq!(info.y_mean, 0.0);
        assert!(!info.intercept);

        // ...but it is still scaled to unit norm.
        let y_norm: f64 = y_out.iter().map(|&v| v * v).sum::<f64>().sqrt();
        assert_close(y_norm, 1.0);
    }

    #[test]
    fn test_standardize_xy_constant_y() {
        // A constant response centers to all zeros, so there is no norm to divide by.
        let response = vec![5.0, 5.0];

        let (_x_out, y_out, info) = standardize_xy(&design_2x3(), &response, &options(true, true, None));

        assert_eq!(y_out, vec![0.0, 0.0]);
        assert_eq!(info.y_mean, 5.0);
        // The recorded scale is either absent or zero.
        assert!(info.y_scale.unwrap_or(0.0) == 0.0);
    }

    #[test]
    fn test_unstandardize_coefficients_no_intercept() {
        let info = fixed_info(0.0, Some(2.0), false, true);

        // Without an intercept every coefficient is a slope.
        let (beta0, beta_slopes) = unstandardize_coefficients(&[1.0, 2.0, 3.0], &info);

        assert_eq!(beta0, 0.0);
        assert_eq!(beta_slopes.len(), 3);
        // slope = (y_scale * coefficient) / x_scale
        assert_close(beta_slopes[0], 2.0); // (2 * 1) / 1
        assert_close(beta_slopes[1], 2.0 * 2.0 / 2.0); // (2 * 2) / 2
        assert_close(beta_slopes[2], 2.0 * 3.0 / 3.0); // (2 * 3) / 3
    }

    #[test]
    fn test_unstandardize_coefficients_no_y_scale() {
        // A missing y_scale falls back to 1.0.
        let info = fixed_info(5.0, None, true, false);

        let (_beta0, beta_slopes) = unstandardize_coefficients(&[0.0, 1.0, 2.0], &info);

        assert_close(beta_slopes[0], 0.5); // (1 * 1) / 2
    }

    #[test]
    fn test_predict_no_intercept_column() {
        // Two columns and two coefficients: no column is treated as the intercept.
        let design = Matrix::new(2, 2, vec![2.0, 3.0, 4.0, 6.0]);

        let preds = predict(&design, 1.0, &[2.0, 3.0]);

        // 1 + 2*2 + 3*3 = 14
        assert_close(preds[0], 14.0);
        // 1 + 2*4 + 3*6 = 27
        assert_close(preds[1], 27.0);
    }

    #[test]
    fn test_predict_beta_longer_than_columns() {
        let design = Matrix::new(1, 3, vec![1.0, 2.0, 3.0]);
        // Four coefficients for three columns: 4 != 3 - 1, so no intercept column is
        // assumed and the fourth coefficient has no column to multiply.
        let preds = predict(&design, 5.0, &[1.0, 2.0, 3.0, 4.0]);

        // 5 + 1*1 + 2*2 + 3*3 = 19
        assert_close(preds[0], 19.0);
    }

    #[test]
    fn test_standardize_xy_no_standardize_x() {
        let response = vec![3.0, 5.0];

        let (x_out, _y_out, info) = standardize_xy(&design_2x3(), &response, &options(true, false, None));

        // The intercept column must come back untouched.
        assert_eq!(x_out.get(0, 0), 1.0);
        assert_eq!(x_out.get(1, 0), 1.0);

        // With an intercept but no scaling, columns are centered and multiplied by
        // sqrt(1/n): column 1 is [2, 4] -> centered [-1, 1] -> [-sqrt(1/2), sqrt(1/2)].
        let sqrt_half = (0.5_f64).sqrt();
        assert_close(x_out.get(0, 1), -sqrt_half);
        assert_close(x_out.get(1, 1), sqrt_half);

        // Unscaled columns report a scale of exactly 1.
        assert_eq!(info.x_scale[1], 1.0);
        assert_eq!(info.x_scale[2], 1.0);
    }
}
980}