linreg_core/regularized/preprocess.rs
1//! Data preprocessing for regularized regression.
2//!
3//! This module provides standardization utilities that match glmnet output behavior:
4//!
5//! - Predictors are centered and scaled (if enabled)
6//! - The intercept column is not penalized, so it's handled specially
7//! - Coefficients can be unstandardized back to the original scale
8//! - Observation weights are supported for weighted regression
9//!
10//! # Standardization Convention
11//!
//! Each standardized predictor is divided by its (weighted) population standard
//! deviation, `sqrt(sum((x - mean(x))²) / n)` in the unweighted case, which gives unit
//! variance under the 1/n convention (matching the glmnet paper).
14//!
15//! # Weighted Standardization
16//!
17//! When weights are provided, they are first normalized to sum to 1:
18//! `weights_normalized = w / sum(w)`. Then weighted means and variances are computed.
19
20use crate::linalg::Matrix;
21
22/// Information stored during standardization, used to unstandardize coefficients.
23///
24/// This struct captures all the information needed to transform coefficients
25/// from the standardized space back to the original data scale.
26///
27/// # Fields
28///
29/// * `x_mean` - Mean of each predictor column (length p)
30/// * `x_scale` - Scale factor for each predictor column (length p)
31/// * `y_mean` - Mean of response variable
32/// * `y_scale` - Scale factor for response (optional, used for lambda path)
33/// * `intercept` - Whether an intercept term was included
34/// * `standardized_x` - Whether X was standardized
35/// * `standardized_y` - Whether y was standardized
36///
37/// # Example
38///
39/// ```
40/// # use linreg_core::regularized::preprocess::StandardizationInfo;
41/// let info = StandardizationInfo {
42/// x_mean: vec![0.0, 5.0],
43/// x_scale: vec![1.0, 2.0],
44/// column_squared_norms: vec![1.0, 1.0],
45/// y_mean: 10.0,
46/// y_scale: Some(3.0),
47/// y_scale_before_sqrt_weights_normalized: Some(3.0),
48/// intercept: true,
49/// standardized_x: true,
50/// standardized_y: false,
51/// };
52///
53/// assert_eq!(info.x_mean.len(), 2);
54/// assert!(info.intercept);
55/// ```
56#[derive(Clone, Debug)]
57pub struct StandardizationInfo {
58 /// Mean of each predictor column
59 pub x_mean: Vec<f64>,
60 /// Scale factor for each predictor column
61 pub x_scale: Vec<f64>,
62 /// Squared norm of each predictor column after standardization.
63 /// This is used in the coordinate descent update denominator.
64 /// - With intercept and standardize: column_squared_norms\[j\] = 1.0 (unit norm after centering)
65 /// - Without intercept and standardize: column_squared_norms\[j\] = 1.0 + x_squared_mean/x_centered_variance (glmnet formula)
66 /// - Without standardize: column_squared_norms\[j\] = ||x_j||^2 (actual squared norm)
67 pub column_squared_norms: Vec<f64>,
68 /// Mean of response variable
69 pub y_mean: f64,
70 /// Scale factor for response (for lambda path construction)
71 /// This is the norm AFTER sqrt_weights_normalized transformation and centering: sqrt(sum((sqrt_weights_normalized*(y-ym))^2))
72 pub y_scale: Option<f64>,
73 /// Scale factor for response BEFORE sqrt_weights_normalized transformation: sqrt(sum((y-ym)^2))
74 /// This is used for lambda scaling between original and standardized data
75 pub y_scale_before_sqrt_weights_normalized: Option<f64>,
76 /// Whether an intercept was included
77 pub intercept: bool,
78 /// Whether X was standardized
79 pub standardized_x: bool,
80 /// Whether y was standardized
81 pub standardized_y: bool,
82}
83
84/// Options for standardization.
85///
86/// # Fields
87///
88/// * `intercept` - Whether to include/center an intercept (default: true)
89/// * `standardize_x` - Whether to standardize predictors (default: true)
90/// * `standardize_y` - Whether to standardize response (default: false)
91/// * `weights` - Optional observation weights (default: None)
92/// If provided, weights are normalized to sum to 1 before use.
93///
94/// # Note
95///
96/// Setting `standardize_y` to `true` is mainly useful when you want to match
97/// glmnet's lambda sequence exactly. For single-lambda fits, you typically
98/// don't need to standardize y.
99///
100/// # Example
101///
102/// ```
103/// # use linreg_core::regularized::preprocess::StandardizeOptions;
104/// let opts = StandardizeOptions {
105/// intercept: true,
106/// standardize_x: true,
107/// standardize_y: false,
108/// weights: None,
109/// };
110///
111/// assert!(opts.intercept);
112/// assert!(opts.standardize_x);
113/// ```
114#[derive(Clone, Debug)]
115pub struct StandardizeOptions {
116 /// Whether to include an intercept (and center X)
117 pub intercept: bool,
118 /// Whether to standardize predictor columns
119 pub standardize_x: bool,
120 /// Whether to standardize the response variable
121 pub standardize_y: bool,
122 /// Optional observation weights (normalized to sum to 1)
123 pub weights: Option<Vec<f64>>,
124}
125
126impl Default for StandardizeOptions {
127 fn default() -> Self {
128 StandardizeOptions {
129 intercept: true,
130 standardize_x: true,
131 standardize_y: false,
132 weights: None,
133 }
134 }
135}
136
137/// Standardizes X and y for regularized regression (glmnet-compatible).
138///
139/// This function performs the same standardization as glmnet with
140/// `standardize=TRUE`. The first column of X is assumed to be the intercept
141/// (all ones) and is NOT standardized.
142///
143/// # Arguments
144///
145/// * `x` - Design matrix (n × p). First column should be intercept if `intercept=true`.
146/// * `y` - Response vector (n elements)
147/// * `options` - Standardization options (including optional observation weights)
148///
149/// # Returns
150///
151/// A tuple `(x_standardized, y_standardized, info)` where:
152/// - `x_standardized` is the standardized design matrix
153/// - `y_standardized` is the (optionally) standardized response
154/// - `info` contains the standardization parameters for unstandardization
155///
156/// # Standardization Details
157///
158/// ## Unweighted case:
159/// For the intercept column (first column, if present):
160/// - Not centered (stays as ones)
161/// - Not scaled
162///
163/// For other columns (if `standardize_x=true`):
164/// - Centered: `x_centered = x - mean(x)`
165/// - Scaled: `x_scaled = x_centered / sqrt(sum(x²))`
166///
167/// For y (if `standardize_y=true`):
168/// - Centered: `y_centered = y - mean(y)`
169/// - Scaled: `y_scaled = y_centered / sqrt(sum(y²))`
170///
171/// ## Weighted case:
172/// Weights are first normalized: `weights_normalized = w / sum(w)`, then `sqrt_weights_normalized = sqrt(weights_normalized)`
173/// - Weighted mean: `ym = sum(w * y) / sum(w) = sum(weights_normalized * y)`
174/// - Weighted variance: `sum(w * (y - ym)^2) / sum(w)`
175/// - Data transformed by `sqrt_weights_normalized`: `y_new = sqrt_weights_normalized * (y - ym)`, then scaled
176#[allow(clippy::needless_range_loop)]
177pub fn standardize_xy(
178 x: &Matrix,
179 y: &[f64],
180 options: &StandardizeOptions,
181) -> (Matrix, Vec<f64>, StandardizationInfo) {
182 let n = x.rows;
183 let p = x.cols;
184
185 // Validate weights if provided
186 if let Some(ref w) = options.weights {
187 if w.len() != n {
188 return (
189 Matrix::new(n, p, vec![0.0; n * p]),
190 vec![0.0; n],
191 StandardizationInfo {
192 x_mean: vec![0.0; p],
193 x_scale: vec![1.0; p],
194 column_squared_norms: vec![0.0; p],
195 y_mean: 0.0,
196 y_scale: None,
197 y_scale_before_sqrt_weights_normalized: None,
198 intercept: options.intercept,
199 standardized_x: options.standardize_x,
200 standardized_y: options.standardize_y,
201 },
202 );
203 }
204 if w.iter().any(|&wi| wi < 0.0) {
205 panic!("Weights must be non-negative");
206 }
207 }
208
209 // Prepare normalized weights and sqrt(weights)
210 // w = w / sum(w) then sqrt_weights_normalized = sqrt(w)
211 let (weights_normalized, sqrt_weights_normalized): (Vec<f64>, Vec<f64>) = if let Some(ref w) = options.weights {
212 let w_sum: f64 = w.iter().sum();
213 if w_sum > 0.0 {
214 let weights_normalized_vec: Vec<f64> = w.iter().map(|&wi| wi / w_sum).collect();
215 let sqrt_weights_normalized_vec: Vec<f64> = weights_normalized_vec.iter().map(|&wi| wi.sqrt()).collect();
216 (weights_normalized_vec, sqrt_weights_normalized_vec)
217 } else {
218 (vec![0.0; n], vec![0.0; n])
219 }
220 } else {
221 // No weights: use uniform weights
222 let w_uniform = vec![1.0 / n as f64; n];
223 let sqrt_weights_normalized_uniform = vec![1.0 / (n as f64).sqrt(); n];
224 (w_uniform, sqrt_weights_normalized_uniform)
225 };
226
227 let mut x_standardized = x.clone();
228 let mut y_standardized = y.to_vec();
229
230 let mut x_mean = vec![0.0; p];
231 let mut x_scale = vec![1.0; p];
232 let mut column_squared_norms = vec![0.0; p]; // Column squared norms for coordinate descent
233
234 let y_mean = if options.intercept && !y.is_empty() {
235 // Weighted mean: ym = sum(w * y)
236 weights_normalized.iter().zip(y.iter()).map(|(&w, &yi)| w * yi).sum()
237 } else {
238 0.0
239 };
240
241 // GLMNET: y is ALWAYS scaled to unit norm!
242 // This is critical for correct lambda_max computation
243 let (y_scale, y_scale_before_sqrt_weights_normalized) = if options.intercept {
244 // WITH INTERCEPT: Center y, then scale to unit norm
245 // First compute y_scale_before_sqrt_weights_normalized (centered but not sqrt_weights_normalized-transformed)
246 let y_centered: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
247 let y_ss_before_sqrt_weights_normalized: f64 = y_centered.iter().map(|&yi| yi * yi).sum();
248 let y_scale_before_sqrt_weights_normalized_val = y_ss_before_sqrt_weights_normalized.sqrt();
249
250 // Center y: y_new = sqrt_weights_normalized * (y - ym)
251 for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
252 *yi = sqrt_weight * (*yi - y_mean);
253 }
254
255 // Scale to unit norm (GLMNET always does this!)
256 let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
257 let y_scale_val = y_ss.sqrt();
258 if y_scale_val > 0.0 {
259 for yi in y_standardized.iter_mut() {
260 *yi /= y_scale_val;
261 }
262 }
263 (Some(y_scale_val), Some(y_scale_before_sqrt_weights_normalized_val))
264 } else {
265 // WITHOUT INTERCEPT: Don't center y, but DO scale to unit norm (GLMNET output behavior!)
266 // y_new = sqrt_weights_normalized * y, then y = y / ||y||
267 for (yi, &sqrt_weight) in y_standardized.iter_mut().zip(&sqrt_weights_normalized) {
268 *yi *= sqrt_weight;
269 }
270 let y_ss: f64 = y_standardized.iter().map(|&yi| yi * yi).sum();
271 let y_scale_val = y_ss.sqrt();
272 if y_scale_val > 0.0 {
273 for yi in y_standardized.iter_mut() {
274 *yi /= y_scale_val;
275 }
276 }
277 (Some(y_scale_val), Some(y_scale_val)) // y_scale_before_sqrt_weights_normalized = y_scale when no centering
278 };
279
280 // Standardize X columns
281 // If intercept is present, first column is NOT standardized (it's the intercept column)
282 let first_penalized_column_index = if options.intercept { 1 } else { 0 };
283
284 if options.intercept {
285 // WITH INTERCEPT (intercept column not standardized)
286 for j in first_penalized_column_index..p {
287 // Compute weighted column mean and center
288 let col_mean: f64 = (0..n)
289 .map(|i| x_standardized.get(i, j) * weights_normalized[i])
290 .sum();
291 x_mean[j] = col_mean;
292
293 // Center the column and apply sqrt_weights_normalized transformation
294 // x_new = sqrt_weights_normalized * (x - xm)
295 for i in 0..n {
296 let val = sqrt_weights_normalized[i] * (x_standardized.get(i, j) - col_mean);
297 x_standardized.set(i, j, val);
298 }
299
300 // Compute squared norm
301 let col_squared_norm_val: f64 = (0..n)
302 .map(|i| {
303 let val = x_standardized.get(i, j);
304 val * val
305 })
306 .sum();
307
308 if options.standardize_x {
309 // Scale to unit norm
310 let col_scale = col_squared_norm_val.sqrt();
311 if col_scale > 0.0 {
312 for i in 0..n {
313 let val = x_standardized.get(i, j) / col_scale;
314 x_standardized.set(i, j, val);
315 }
316 x_scale[j] = col_scale;
317 column_squared_norms[j] = 1.0; // Unit norm
318 }
319 } else {
320 // No standardization - column_squared_norms stays as the actual squared norm
321 column_squared_norms[j] = col_squared_norm_val;
322 x_scale[j] = 1.0;
323 }
324 }
325 } else {
326 // WITHOUT INTERCEPT (no centering)
327 for j in first_penalized_column_index..p {
328 x_mean[j] = 0.0; // No centering
329
330 // Apply sqrt_weights_normalized transformation
331 for i in 0..n {
332 let val = sqrt_weights_normalized[i] * x_standardized.get(i, j);
333 x_standardized.set(i, j, val);
334 }
335
336 // Compute squared norm after sqrt_weights_normalized transformation
337 let col_squared_norm_val: f64 = (0..n)
338 .map(|i| {
339 let val = x_standardized.get(i, j);
340 val * val
341 })
342 .sum();
343
344 if options.standardize_x {
345 // GLMNET special formula for no-intercept case:
346 // x_squared_mean = dot_product(sqrt_weights_normalized, x)^2 (squared mean)
347 // x_centered_variance = col_squared_norm - x_squared_mean (variance-like quantity)
348 // xs = sqrt(x_centered_variance)
349 // column_squared_norms_final = 1.0 + x_squared_mean / x_centered_variance
350 let x_squared_mean: f64 = (0..n)
351 .map(|i| sqrt_weights_normalized[i] * x_standardized.get(i, j))
352 .sum::<f64>().powi(2);
353 let x_centered_variance = col_squared_norm_val - x_squared_mean;
354
355 if x_centered_variance > 0.0 {
356 let col_scale = x_centered_variance.sqrt();
357 // Scale by col_scale (NOT by sqrt(col_squared_norm_val))
358 for i in 0..n {
359 let val = x_standardized.get(i, j) / col_scale;
360 x_standardized.set(i, j, val);
361 }
362 x_scale[j] = col_scale;
363 column_squared_norms[j] = 1.0 + x_squared_mean / x_centered_variance; // GLMNET formula
364 } else {
365 column_squared_norms[j] = 1.0;
366 x_scale[j] = 1.0;
367 }
368 } else {
369 // No standardization
370 column_squared_norms[j] = col_squared_norm_val;
371 x_scale[j] = 1.0;
372 }
373 }
374 }
375
376 // If intercept column exists, set its scale to 1.0 (not penalized)
377 if options.intercept && p > 0 {
378 x_scale[0] = 1.0;
379 x_mean[0] = 0.0; // Intercept column has no "mean" to subtract
380 column_squared_norms[0] = 1.0; // Intercept column is not penalized
381 }
382
383 let info = StandardizationInfo {
384 x_mean,
385 x_scale,
386 column_squared_norms,
387 y_mean,
388 y_scale,
389 y_scale_before_sqrt_weights_normalized,
390 intercept: options.intercept,
391 standardized_x: options.standardize_x,
392 standardized_y: options.standardize_y,
393 };
394
395 (x_standardized, y_standardized, info)
396}
397
398/// Unstandardizes coefficients from the standardized space back to original scale.
399///
400/// This reverses the standardization transformation to get coefficients that
401/// can be applied to the original (unscaled) data.
402///
403/// # Arguments
404///
405/// * `coefficients_standardized` - Coefficients in standardized space (length p)
406/// * `info` - Standardization information from [`standardize_xy`]
407///
408/// # Returns
409///
410/// A tuple `(beta0, beta_slopes)` where:
411/// - `beta0` is the intercept on the original scale
412/// - `beta_slopes` are the slope coefficients only (excluding intercept column coefficient)
413///
414/// # Unstandardization Formula
415///
416/// For non-intercept coefficients:
417/// ```text
418/// β_original[j] = (y_scale * β_std[j]) / x_scale[j]
419/// ```
420///
421/// For the intercept:
422/// ```text
423/// β₀ = y_mean - Σⱼ x_mean[j] * β_original[j]
424/// ```
425///
426/// If y was not standardized, `y_scale = 1`.
427/// If X was not standardized, `x_scale[j] = 1`.
428///
429/// # Note
430///
431/// If `intercept=true` in the info, `coefficients_standardized[0]` is assumed to be the intercept
432/// coefficient (which is already 0 in the standardized space since X was centered).
433/// The returned `beta_slopes` will NOT include this zeroed coefficient - only actual
434/// slope coefficients are returned.
435///
436/// # Example
437///
438/// ```
439/// # use linreg_core::regularized::preprocess::{unstandardize_coefficients, StandardizationInfo};
440/// let info = StandardizationInfo {
441/// x_mean: vec![0.0, 5.0],
442/// x_scale: vec![1.0, 2.0],
443/// column_squared_norms: vec![1.0, 1.0],
444/// y_mean: 10.0,
445/// y_scale: Some(3.0),
446/// y_scale_before_sqrt_weights_normalized: Some(3.0),
447/// intercept: true,
448/// standardized_x: true,
449/// standardized_y: false,
450/// };
451///
452/// // Standardized coefficients: [intercept=0, slope1=2.0]
453/// let coefficients_standardized = vec![0.0, 2.0];
454/// let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
455///
456/// // slope_original = (y_scale * slope_std) / x_scale[1]
457/// // = (3.0 * 2.0) / 2.0 = 3.0
458/// assert!((beta_slopes[0] - 3.0).abs() < 0.01);
459/// ```
460#[allow(clippy::needless_range_loop)]
461pub fn unstandardize_coefficients(coefficients_standardized: &[f64], info: &StandardizationInfo) -> (f64, Vec<f64>) {
462 let p = coefficients_standardized.len();
463 let y_scale = info.y_scale.unwrap_or(1.0);
464
465 // Determine where slope coefficients start in coefficients_standardized
466 let start_idx = if info.intercept { 1 } else { 0 };
467 let n_slopes = p - start_idx;
468
469 // Unstandardize slope coefficients only (exclude intercept column coefficient)
470 // NOTE: X is ALWAYS standardized for the solver, so we always apply the unstandardization formula.
471 // The user's `standardize_x` option doesn't affect the internal computation, only the
472 // interpretation of results.
473 let mut beta_slopes = vec![0.0; n_slopes];
474 for j in start_idx..p {
475 let slope_idx = j - start_idx;
476 // Standard unstandardization: beta_original = (y_scale * coefficients_standardized) / x_scale
477 // This converts from the standardized space back to original data scale
478 beta_slopes[slope_idx] = (y_scale * coefficients_standardized[j]) / info.x_scale[j];
479 }
480
481 // Compute intercept on original scale
482 // beta0 = y_mean - sum(x_mean[j] * beta_slopes[j-1]) for j in 1..p
483 let beta0 = if info.intercept {
484 let mut sum = 0.0;
485 for j in 1..p {
486 sum += info.x_mean[j] * beta_slopes[j - 1];
487 }
488 info.y_mean - sum
489 } else {
490 0.0
491 };
492
493 (beta0, beta_slopes)
494}
495
496/// Computes predictions using unstandardized coefficients.
497///
498/// # Arguments
499///
500/// * `x_new` - New data matrix (n_new × p, with intercept column if applicable)
501/// * `beta0` - Intercept on original scale
502/// * `beta` - Slope coefficients on original scale (does NOT include intercept column coefficient)
503///
504/// # Returns
505///
506/// Predictions for each row in x_new.
507///
508/// # Note
509///
510/// If `x_new` has an intercept column (first column of all ones), `beta` should have
511/// `p - 1` elements corresponding to the non-intercept columns. If `x_new` has no
512/// intercept column, `beta` should have `p` elements.
513///
514/// # Example
515///
516/// ```
517/// # use linreg_core::regularized::preprocess::predict;
518/// # use linreg_core::linalg::Matrix;
519/// // X matrix with intercept: [[1, 2], [1, 3], [1, 4]]
520/// let x_new = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
521/// let beta0 = 1.0;
522/// let beta = vec![2.0]; // One slope coefficient
523///
524/// // predictions[i] = 1.0 + 2.0 * x[i,1]
525/// let preds = predict(&x_new, beta0, &beta);
526/// assert_eq!(preds, vec![5.0, 7.0, 9.0]);
527/// ```
528#[allow(clippy::needless_range_loop)]
529pub fn predict(x_new: &Matrix, beta0: f64, beta: &[f64]) -> Vec<f64> {
530 let n = x_new.rows;
531 let p = x_new.cols;
532
533 let mut predictions = vec![0.0; n];
534
535 // Determine if there's an intercept column based on beta length
536 // If beta has one fewer element than columns, first column is intercept
537 let has_intercept_col = beta.len() == p - 1;
538 let first_penalized_column_index = if has_intercept_col { 1 } else { 0 };
539
540 for i in 0..n {
541 let mut sum = beta0;
542 for (j, &beta_j) in beta.iter().enumerate() {
543 let col = first_penalized_column_index + j;
544 if col < p {
545 sum += x_new.get(i, col) * beta_j;
546 }
547 }
548 predictions[i] = sum;
549 }
550
551 predictions
552}
553
554#[cfg(test)]
555mod tests {
556 use super::*;
557
558 #[test]
559 fn test_standardize_xy_with_intercept() {
560 // Simple test data
561 let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0];
562 let x = Matrix::new(3, 3, x_data);
563 let y = vec![3.0, 5.0, 7.0];
564
565 let options = StandardizeOptions {
566 intercept: true,
567 standardize_x: true,
568 standardize_y: false, // Note: y is still scaled to unit norm by glmnet convention
569 weights: None,
570 };
571
572 let (x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
573
574 // First column (intercept) should be unchanged
575 assert_eq!(x_standardized.get(0, 0), 1.0);
576 assert_eq!(x_standardized.get(1, 0), 1.0);
577 assert_eq!(x_standardized.get(2, 0), 1.0);
578
579 // GLMNET: y is ALWAYS scaled to unit norm
580 // y_centered = y - y_mean = [-2, 0, 2]
581 // sqrt_weights_normalized-transform: y_sqrt_weights = sqrt_weights_normalized * y_centered = [-2/sqrt(3), 0, 2/sqrt(3)]
582 // Scale to unit norm: y_standardized = y_sqrt_weights / ||y_sqrt_weights|| = [-1/sqrt(2), 0, 1/sqrt(2)]
583 let inv_sqrt2 = 1.0 / (2.0_f64).sqrt();
584 assert!((y_standardized[0] - (-inv_sqrt2)).abs() < 1e-10);
585 assert!((y_standardized[1] - 0.0).abs() < 1e-10);
586 assert!((y_standardized[2] - inv_sqrt2).abs() < 1e-10);
587
588 // x_mean should capture the column means
589 assert_eq!(info.x_mean[0], 0.0); // Intercept column
590 assert!((info.x_mean[1] - 4.0).abs() < 1e-10);
591 assert!((info.x_mean[2] - 6.0).abs() < 1e-10);
592 }
593
594 #[test]
595 fn test_unstandardize_coefficients() {
596 // Create a simple standardization scenario
597 let x_mean = vec![0.0, 4.0, 6.0];
598 let x_scale = vec![1.0, 2.0, 3.0];
599 let column_squared_norms = vec![1.0, 1.0, 1.0]; // Unit norm after standardization
600 let y_mean = 5.0;
601 let y_scale = Some(2.0);
602
603 let info = StandardizationInfo {
604 x_mean: x_mean.clone(),
605 x_scale: x_scale.clone(),
606 column_squared_norms,
607 y_mean,
608 y_scale,
609 y_scale_before_sqrt_weights_normalized: None,
610 intercept: true,
611 standardized_x: true,
612 standardized_y: true,
613 };
614
615 // Coefficients in standardized space: [intercept=0, beta1=1, beta2=2]
616 let coefficients_standardized = vec![0.0, 1.0, 2.0];
617
618 let (beta0, beta_slopes) = unstandardize_coefficients(&coefficients_standardized, &info);
619
620 // Check unstandardization - beta_slopes now only contains slope coefficients
621 // beta_slopes[0] = (y_scale * coefficients_standardized[1]) / x_scale[1] = (2 * 1) / 2 = 1
622 assert!((beta_slopes[0] - 1.0).abs() < 1e-10);
623 // beta_slopes[1] = (y_scale * coefficients_standardized[2]) / x_scale[2] = (2 * 2) / 3 = 4/3
624 assert!((beta_slopes[1] - 4.0 / 3.0).abs() < 1e-10);
625
626 // beta0 = y_mean - sum(x_mean[j] * beta_slopes[j-1])
627 // = 5 - (4 * 1 + 6 * 4/3) = 5 - 4 - 8 = -7
628 assert!((beta0 - (-7.0)).abs() < 1e-10);
629
630 // Verify beta_slopes has the correct length (only slopes, not intercept col coef)
631 assert_eq!(beta_slopes.len(), 2);
632 }
633
634 #[test]
635 fn test_predict() {
636 // X has intercept column (first col all 1s) plus 2 predictors
637 let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
638 let x = Matrix::new(2, 3, x_data);
639
640 // beta0 = 1, beta = [2.0, 3.0] (slope coefficients only, no intercept col coef)
641 let beta0 = 1.0;
642 let beta = vec![2.0, 3.0];
643
644 let preds = predict(&x, beta0, &beta);
645
646 // pred[0] = 1 + 2*2 + 3*3 = 1 + 4 + 9 = 14
647 assert!((preds[0] - 14.0).abs() < 1e-10);
648 // pred[1] = 1 + 2*4 + 3*6 = 1 + 8 + 18 = 27
649 assert!((preds[1] - 27.0).abs() < 1e-10);
650 }
651
652 #[test]
653 fn test_weighted_standardize_xy() {
654 // Simple test data
655 let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0, 1.0, 6.0, 9.0];
656 let x = Matrix::new(3, 3, x_data);
657 let y = vec![3.0, 5.0, 7.0];
658
659 // Weights: give more weight to the middle observation
660 let weights = vec![1.0, 2.0, 1.0];
661
662 let options = StandardizeOptions {
663 intercept: true,
664 standardize_x: true,
665 standardize_y: false, // Note: y is still scaled to unit norm by glmnet convention
666 weights: Some(weights),
667 };
668
669 let (x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options);
670
671 // First column (intercept) should be unchanged
672 assert_eq!(x_standardized.get(0, 0), 1.0);
673 assert_eq!(x_standardized.get(1, 0), 1.0);
674 assert_eq!(x_standardized.get(2, 0), 1.0);
675
676 // y_mean should be weighted mean
677 // weights normalized: [1/4, 2/4, 1/4] = [0.25, 0.5, 0.25]
678 // weighted mean: 0.25*3 + 0.5*5 + 0.25*7 = 0.75 + 2.5 + 1.75 = 5.0
679 assert!((info.y_mean - 5.0).abs() < 1e-10);
680
681 // GLMNET: y is ALWAYS scaled to unit norm
682 // y_centered = y - y_mean = [-2, 0, 2]
683 // sqrt_weights_normalized = sqrt([0.25, 0.5, 0.25]) = [0.5, ~0.707, 0.5]
684 // y_sqrt_weights = sqrt_weights_normalized * y_centered = [-1, 0, 1]
685 // sum(y_sqrt_weights^2) = 2, so y_scale = sqrt(2)
686 // y_standardized = y_sqrt_weights / y_scale = [-1/sqrt(2), 0, 1/sqrt(2)]
687 let expected_0 = -1.0 / (2.0_f64).sqrt();
688 assert!((y_standardized[0] - expected_0).abs() < 1e-10);
689 assert!((y_standardized[1] - 0.0).abs() < 1e-10);
690 assert!((y_standardized[2] + expected_0).abs() < 1e-10); // Should be 1/sqrt(2)
691 }
692
693 #[test]
694 fn test_weighted_standardize_uniform_weights() {
695 // Test that uniform weights give same result as no weights
696 let x_data = vec![1.0, 2.0, 3.0, 1.0, 4.0, 6.0];
697 let x = Matrix::new(2, 3, x_data);
698 let y = vec![3.0, 5.0];
699
700 // Uniform weights (should be equivalent to no weights after normalization)
701 let weights = vec![1.0, 1.0];
702
703 let options_with_weights = StandardizeOptions {
704 intercept: true,
705 standardize_x: true,
706 standardize_y: false,
707 weights: Some(weights),
708 };
709
710 let options_no_weights = StandardizeOptions {
711 intercept: true,
712 standardize_x: true,
713 standardize_y: false,
714 weights: None,
715 };
716
717 let (_x_standardized_w, y_standardized_w, info_w) = standardize_xy(&x, &y, &options_with_weights);
718 let (_x_standardized, y_standardized, info) = standardize_xy(&x, &y, &options_no_weights);
719
720 // Results should be the same
721 assert_eq!(info_w.y_mean, info.y_mean);
722 for i in 0..2 {
723 assert!((y_standardized_w[i] - y_standardized[i]).abs() < 1e-10);
724 }
725 }
726}