Skip to main content

sklears_simd/
regression.rs

1//! SIMD-optimized regression operations
2//!
3//! This module provides vectorized implementations of common regression operations
4//! including least squares, ridge regression, and elastic net computations.
5
6#[cfg(feature = "no-std")]
7use alloc::{vec, vec::Vec};
8
9/// SIMD-optimized ordinary least squares computation
10/// Computes X^T * X and X^T * y for the normal equation
11pub fn least_squares_normal_equation(
12    x: &[&[f32]], // Design matrix (n_samples x n_features)
13    y: &[f32],    // Target values (n_samples)
14) -> (Vec<Vec<f32>>, Vec<f32>) {
15    let n_samples = x.len();
16    let n_features = if n_samples > 0 { x[0].len() } else { 0 };
17
18    assert!(!x.is_empty(), "Design matrix cannot be empty");
19    assert_eq!(
20        y.len(),
21        n_samples,
22        "Target length must match number of samples"
23    );
24
25    // Initialize X^T * X matrix
26    let mut xtx = vec![vec![0.0f32; n_features]; n_features];
27    // Initialize X^T * y vector
28    let mut xty = vec![0.0f32; n_features];
29
30    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
31    {
32        if crate::simd_feature_detected!("avx2") {
33            unsafe { least_squares_avx2(x, y, &mut xtx, &mut xty) };
34            return (xtx, xty);
35        } else if crate::simd_feature_detected!("sse2") {
36            unsafe { least_squares_sse2(x, y, &mut xtx, &mut xty) };
37            return (xtx, xty);
38        }
39    }
40
41    least_squares_scalar(x, y, &mut xtx, &mut xty);
42    (xtx, xty)
43}
44
/// Scalar fallback that fills `xtx` with `X^T * X` and `xty` with `X^T * y`.
///
/// `x` holds one slice per sample; the feature count is taken from the first
/// row. `xtx` must be at least `n_features x n_features` and `xty` at least
/// `n_features` long. An empty `x` leaves both outputs untouched.
///
/// # Panics
///
/// Panics if any row has fewer elements than the first row, if `y` is shorter
/// than `x`, or if the output buffers are too small.
fn least_squares_scalar(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    let n_samples = x.len();
    let n_features = x.first().map_or(0, |row| row.len());

    // X^T * X is symmetric (row[i] * row[j] == row[j] * row[i]), so only the
    // upper triangle is accumulated and each value is mirrored below the
    // diagonal. This halves the work without changing any result.
    for i in 0..n_features {
        for j in i..n_features {
            let sum: f32 = x.iter().map(|row| row[i] * row[j]).sum();
            xtx[i][j] = sum;
            xtx[j][i] = sum;
        }
    }

    // X^T * y: one dot product between each feature column and the targets.
    for i in 0..n_features {
        let mut sum = 0.0;
        for k in 0..n_samples {
            sum += x[k][i] * y[k];
        }
        xty[i] = sum;
    }
}
66
/// SSE2 kernel that fills `xtx` with `X^T * X` and `xty` with `X^T * y`.
///
/// Samples are processed four at a time; leftover samples are handled by a
/// scalar tail loop. The feature count is taken from the first row of `x`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
///
/// # Panics
///
/// Panics if a row is shorter than the first row, if `y` is shorter than `x`,
/// or if the output buffers are smaller than the feature count.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn least_squares_sse2(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_samples = x.len();
    let n_features = x.first().map_or(0, |row| row.len());

    // X^T * X is symmetric: accumulate the upper triangle only and mirror it.
    for i in 0..n_features {
        for j in i..n_features {
            let mut sum = _mm_setzero_ps();
            let mut k = 0;

            while k + 4 <= n_samples {
                // Column data is strided across rows, so lanes are gathered
                // element by element (each access is bounds-checked).
                let xi_vec = _mm_setr_ps(x[k][i], x[k + 1][i], x[k + 2][i], x[k + 3][i]);
                let xj_vec = _mm_setr_ps(x[k][j], x[k + 1][j], x[k + 2][j], x[k + 3][j]);
                sum = _mm_add_ps(sum, _mm_mul_ps(xi_vec, xj_vec));
                k += 4;
            }

            // Horizontal reduction of the four partial sums.
            let mut result = [0.0f32; 4];
            _mm_storeu_ps(result.as_mut_ptr(), sum);
            let mut scalar_sum = result[0] + result[1] + result[2] + result[3];

            // Scalar tail for the samples that did not fill a full vector.
            while k < n_samples {
                scalar_sum += x[k][i] * x[k][j];
                k += 1;
            }

            xtx[i][j] = scalar_sum;
            xtx[j][i] = scalar_sum;
        }
    }

    // X^T * y
    for i in 0..n_features {
        let mut sum = _mm_setzero_ps();
        let mut k = 0;

        while k + 4 <= n_samples {
            let xi_vec = _mm_setr_ps(x[k][i], x[k + 1][i], x[k + 2][i], x[k + 3][i]);
            // Slice first so the 4-wide load is bounds-checked and the pointer
            // is valid for all four elements.
            let y_vec = _mm_loadu_ps(y[k..k + 4].as_ptr());
            sum = _mm_add_ps(sum, _mm_mul_ps(xi_vec, y_vec));
            k += 4;
        }

        let mut result = [0.0f32; 4];
        _mm_storeu_ps(result.as_mut_ptr(), sum);
        let mut scalar_sum = result[0] + result[1] + result[2] + result[3];

        while k < n_samples {
            scalar_sum += x[k][i] * y[k];
            k += 1;
        }

        xty[i] = scalar_sum;
    }
}
127
/// AVX2 kernel that fills `xtx` with `X^T * X` and `xty` with `X^T * y`.
///
/// Samples are processed eight at a time; leftover samples are handled by a
/// scalar tail loop. The feature count is taken from the first row of `x`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if a row is shorter than the first row, if `y` is shorter than `x`,
/// or if the output buffers are smaller than the feature count.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn least_squares_avx2(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_samples = x.len();
    let n_features = x.first().map_or(0, |row| row.len());

    // X^T * X is symmetric: accumulate the upper triangle only and mirror it.
    for i in 0..n_features {
        for j in i..n_features {
            let mut sum = _mm256_setzero_ps();
            let mut k = 0;

            while k + 8 <= n_samples {
                // Column data is strided across rows, so lanes are gathered
                // element by element (each access is bounds-checked).
                let xi_vec = _mm256_setr_ps(
                    x[k][i],
                    x[k + 1][i],
                    x[k + 2][i],
                    x[k + 3][i],
                    x[k + 4][i],
                    x[k + 5][i],
                    x[k + 6][i],
                    x[k + 7][i],
                );
                let xj_vec = _mm256_setr_ps(
                    x[k][j],
                    x[k + 1][j],
                    x[k + 2][j],
                    x[k + 3][j],
                    x[k + 4][j],
                    x[k + 5][j],
                    x[k + 6][j],
                    x[k + 7][j],
                );
                sum = _mm256_add_ps(sum, _mm256_mul_ps(xi_vec, xj_vec));
                k += 8;
            }

            // Horizontal reduction of the eight partial sums.
            let mut result = [0.0f32; 8];
            _mm256_storeu_ps(result.as_mut_ptr(), sum);
            let mut scalar_sum = result.iter().sum::<f32>();

            // Scalar tail for the samples that did not fill a full vector.
            while k < n_samples {
                scalar_sum += x[k][i] * x[k][j];
                k += 1;
            }

            xtx[i][j] = scalar_sum;
            xtx[j][i] = scalar_sum;
        }
    }

    // X^T * y
    for i in 0..n_features {
        let mut sum = _mm256_setzero_ps();
        let mut k = 0;

        while k + 8 <= n_samples {
            let xi_vec = _mm256_setr_ps(
                x[k][i],
                x[k + 1][i],
                x[k + 2][i],
                x[k + 3][i],
                x[k + 4][i],
                x[k + 5][i],
                x[k + 6][i],
                x[k + 7][i],
            );
            // Slice first so the 8-wide load is bounds-checked and the pointer
            // is valid for all eight elements.
            let y_vec = _mm256_loadu_ps(y[k..k + 8].as_ptr());
            sum = _mm256_add_ps(sum, _mm256_mul_ps(xi_vec, y_vec));
            k += 8;
        }

        let mut result = [0.0f32; 8];
        _mm256_storeu_ps(result.as_mut_ptr(), sum);
        let mut scalar_sum = result.iter().sum::<f32>();

        while k < n_samples {
            scalar_sum += x[k][i] * y[k];
            k += 1;
        }

        xty[i] = scalar_sum;
    }
}
215
216/// SIMD-optimized ridge regression normal equation computation
217/// Computes (X^T * X + alpha * I) and X^T * y
218pub fn ridge_regression_normal_equation(
219    x: &[&[f32]], // Design matrix (n_samples x n_features)
220    y: &[f32],    // Target values (n_samples)
221    alpha: f32,   // Regularization parameter
222) -> (Vec<Vec<f32>>, Vec<f32>) {
223    let (mut xtx, xty) = least_squares_normal_equation(x, y);
224
225    // Add ridge regularization: X^T * X + alpha * I
226    for (i, row) in xtx.iter_mut().enumerate() {
227        row[i] += alpha;
228    }
229
230    (xtx, xty)
231}
232
233/// SIMD-optimized elastic net penalty computation
234/// Computes the elastic net penalty: alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||_2^2
235pub fn elastic_net_penalty(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
236    assert!(
237        (0.0..=1.0).contains(&l1_ratio),
238        "l1_ratio must be between 0 and 1"
239    );
240    assert!(alpha >= 0.0, "alpha must be non-negative");
241
242    if weights.is_empty() {
243        return 0.0;
244    }
245
246    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
247    {
248        if crate::simd_feature_detected!("avx2") {
249            return unsafe { elastic_net_penalty_avx2(weights, alpha, l1_ratio) };
250        } else if crate::simd_feature_detected!("sse2") {
251            return unsafe { elastic_net_penalty_sse2(weights, alpha, l1_ratio) };
252        }
253    }
254
255    elastic_net_penalty_scalar(weights, alpha, l1_ratio)
256}
257
/// Scalar elastic-net penalty:
/// `alpha * l1_ratio * sum(|w|) + 0.5 * alpha * (1 - l1_ratio) * sum(w^2)`.
fn elastic_net_penalty_scalar(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    // Sum of absolute values (L1 norm).
    let abs_total: f32 = weights.iter().map(|w| w.abs()).sum();
    // Sum of squares (squared L2 norm).
    let square_total: f32 = weights.iter().map(|w| w * w).sum();

    let l1_term = alpha * l1_ratio * abs_total;
    let l2_term = 0.5 * alpha * (1.0 - l1_ratio) * square_total;
    l1_term + l2_term
}
264
/// SSE2 kernel for the elastic-net penalty
/// `alpha * l1_ratio * sum(|w|) + 0.5 * alpha * (1 - l1_ratio) * sum(w^2)`.
///
/// Weights are consumed four at a time; the remainder is handled by a scalar
/// tail loop.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn elastic_net_penalty_sse2(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut l1_sum = _mm_setzero_ps();
    let mut l2_sum = _mm_setzero_ps();
    // -0.0 has only the sign bit set; and-not with it clears the sign (abs).
    let sign_mask = _mm_set1_ps(-0.0f32);

    let chunks = weights.chunks_exact(4);
    let tail = chunks.remainder();

    for chunk in chunks {
        // `chunk` is exactly four elements long, so the load is in bounds.
        let w_vec = _mm_loadu_ps(chunk.as_ptr());

        // L1 norm: sum of absolute values
        l1_sum = _mm_add_ps(l1_sum, _mm_andnot_ps(sign_mask, w_vec));

        // L2 norm squared: sum of squares
        l2_sum = _mm_add_ps(l2_sum, _mm_mul_ps(w_vec, w_vec));
    }

    // Horizontal reduction of the four partial sums of each accumulator.
    let mut l1_result = [0.0f32; 4];
    let mut l2_result = [0.0f32; 4];
    _mm_storeu_ps(l1_result.as_mut_ptr(), l1_sum);
    _mm_storeu_ps(l2_result.as_mut_ptr(), l2_sum);

    let mut l1_scalar = l1_result[0] + l1_result[1] + l1_result[2] + l1_result[3];
    let mut l2_scalar = l2_result[0] + l2_result[1] + l2_result[2] + l2_result[3];

    // Scalar tail for the weights that did not fill a full vector.
    for &w in tail {
        l1_scalar += w.abs();
        l2_scalar += w * w;
    }

    alpha * l1_ratio * l1_scalar + 0.5 * alpha * (1.0 - l1_ratio) * l2_scalar
}
305
/// AVX2 kernel for the elastic-net penalty
/// `alpha * l1_ratio * sum(|w|) + 0.5 * alpha * (1 - l1_ratio) * sum(w^2)`.
///
/// Weights are consumed eight at a time; the remainder is handled by a scalar
/// tail loop.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn elastic_net_penalty_avx2(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut l1_sum = _mm256_setzero_ps();
    let mut l2_sum = _mm256_setzero_ps();
    // -0.0 has only the sign bit set; and-not with it clears the sign (abs).
    let sign_mask = _mm256_set1_ps(-0.0f32);

    let chunks = weights.chunks_exact(8);
    let tail = chunks.remainder();

    for chunk in chunks {
        // `chunk` is exactly eight elements long, so the load is in bounds.
        let w_vec = _mm256_loadu_ps(chunk.as_ptr());

        // L1 norm: sum of absolute values
        l1_sum = _mm256_add_ps(l1_sum, _mm256_andnot_ps(sign_mask, w_vec));

        // L2 norm squared: sum of squares
        l2_sum = _mm256_add_ps(l2_sum, _mm256_mul_ps(w_vec, w_vec));
    }

    // Horizontal reduction of the eight partial sums of each accumulator.
    let mut l1_result = [0.0f32; 8];
    let mut l2_result = [0.0f32; 8];
    _mm256_storeu_ps(l1_result.as_mut_ptr(), l1_sum);
    _mm256_storeu_ps(l2_result.as_mut_ptr(), l2_sum);

    let mut l1_scalar = l1_result.iter().sum::<f32>();
    let mut l2_scalar = l2_result.iter().sum::<f32>();

    // Scalar tail for the weights that did not fill a full vector.
    for &w in tail {
        l1_scalar += w.abs();
        l2_scalar += w * w;
    }

    alpha * l1_ratio * l1_scalar + 0.5 * alpha * (1.0 - l1_ratio) * l2_scalar
}
346
347/// SIMD-optimized soft thresholding for LASSO
348/// Applies soft thresholding: sign(x) * max(|x| - threshold, 0)
349pub fn soft_threshold(values: &[f32], threshold: f32, output: &mut [f32]) {
350    assert_eq!(
351        values.len(),
352        output.len(),
353        "Arrays must have the same length"
354    );
355    assert!(threshold >= 0.0, "Threshold must be non-negative");
356
357    if values.is_empty() {
358        return;
359    }
360
361    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
362    {
363        if crate::simd_feature_detected!("avx2") {
364            unsafe { soft_threshold_avx2(values, threshold, output) };
365            return;
366        } else if crate::simd_feature_detected!("sse2") {
367            unsafe { soft_threshold_sse2(values, threshold, output) };
368            return;
369        }
370    }
371
372    soft_threshold_scalar(values, threshold, output);
373}
374
/// Scalar soft thresholding: entries whose magnitude is at most `threshold`
/// become `0.0`; all others are shrunk towards zero by `threshold`.
fn soft_threshold_scalar(values: &[f32], threshold: f32, output: &mut [f32]) {
    for (idx, &value) in values.iter().enumerate() {
        let magnitude = value.abs();

        // Inside the dead zone: clamp to zero.
        if magnitude <= threshold {
            output[idx] = 0.0;
            continue;
        }

        // Outside the dead zone: keep the sign, shrink the magnitude.
        output[idx] = value.signum() * (magnitude - threshold);
    }
}
385
/// SSE2 kernel for soft thresholding: `sign(v) * max(|v| - threshold, 0)`.
///
/// Values are processed four at a time; the remainder is handled by a scalar
/// tail loop. NaN inputs propagate as NaN in both the vector and the tail
/// path, matching the scalar implementation.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
///
/// # Panics
///
/// Panics if `output` is shorter than `values`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn soft_threshold_sse2(values: &[f32], threshold: f32, output: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n = values.len();
    // Reslice up front so every raw-pointer store below is in bounds; a short
    // `output` panics here instead of being written past its end.
    let output = &mut output[..n];

    let threshold_vec = _mm_set1_ps(threshold);
    // -0.0 has only the sign bit set.
    let sign_mask = _mm_set1_ps(-0.0f32);
    let mut i = 0;

    while i + 4 <= n {
        let val_vec = _mm_loadu_ps(values.as_ptr().add(i));
        let abs_val = _mm_andnot_ps(sign_mask, val_vec);
        let sign_bits = _mm_and_ps(sign_mask, val_vec);

        // |x| - threshold with the sign of x copied back on. For every lane
        // that survives the mask below the difference is positive, so this
        // equals sign(x) * (|x| - threshold).
        let shrunk = _mm_or_ps(_mm_sub_ps(abs_val, threshold_vec), sign_bits);

        // Zero exactly the lanes with |x| <= threshold. The comparison is
        // false for NaN, so NaN lanes keep their NaN result instead of being
        // silently turned into 0.0.
        let dead_zone = _mm_cmple_ps(abs_val, threshold_vec);
        let final_result = _mm_andnot_ps(dead_zone, shrunk);

        _mm_storeu_ps(output.as_mut_ptr().add(i), final_result);
        i += 4;
    }

    // Scalar tail for the values that did not fill a full vector.
    while i < n {
        let abs_val = values[i].abs();
        output[i] = if abs_val <= threshold {
            0.0
        } else {
            values[i].signum() * (abs_val - threshold)
        };
        i += 1;
    }
}
431
/// AVX2 kernel for soft thresholding: `sign(v) * max(|v| - threshold, 0)`.
///
/// Values are processed eight at a time; the remainder is handled by a scalar
/// tail loop. NaN inputs propagate as NaN in both the vector and the tail
/// path, matching the scalar implementation.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `output` is shorter than `values`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn soft_threshold_avx2(values: &[f32], threshold: f32, output: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n = values.len();
    // Reslice up front so every raw-pointer store below is in bounds; a short
    // `output` panics here instead of being written past its end.
    let output = &mut output[..n];

    let threshold_vec = _mm256_set1_ps(threshold);
    // -0.0 has only the sign bit set.
    let sign_mask = _mm256_set1_ps(-0.0f32);
    let mut i = 0;

    while i + 8 <= n {
        let val_vec = _mm256_loadu_ps(values.as_ptr().add(i));
        let abs_val = _mm256_andnot_ps(sign_mask, val_vec);
        let sign_bits = _mm256_and_ps(sign_mask, val_vec);

        // |x| - threshold with the sign of x copied back on. For every lane
        // that survives the mask below the difference is positive, so this
        // equals sign(x) * (|x| - threshold).
        let shrunk = _mm256_or_ps(_mm256_sub_ps(abs_val, threshold_vec), sign_bits);

        // Zero exactly the lanes with |x| <= threshold. The ordered comparison
        // is false for NaN, so NaN lanes keep their NaN result instead of
        // being silently turned into 0.0.
        let dead_zone = _mm256_cmp_ps(abs_val, threshold_vec, _CMP_LE_OQ);
        let final_result = _mm256_andnot_ps(dead_zone, shrunk);

        _mm256_storeu_ps(output.as_mut_ptr().add(i), final_result);
        i += 8;
    }

    // Scalar tail for the values that did not fill a full vector.
    while i < n {
        let abs_val = values[i].abs();
        output[i] = if abs_val <= threshold {
            0.0
        } else {
            values[i].signum() * (abs_val - threshold)
        };
        i += 1;
    }
}
480
481/// SIMD-optimized prediction for linear models
482/// Computes y = X * beta
483pub fn linear_predict(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
484    let n_samples = x.len();
485    let n_features = if n_samples > 0 { x[0].len() } else { 0 };
486
487    assert_eq!(
488        weights.len(),
489        n_features,
490        "Weight length must match number of features"
491    );
492    assert_eq!(
493        output.len(),
494        n_samples,
495        "Output length must match number of samples"
496    );
497
498    if n_samples == 0 {
499        return;
500    }
501
502    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
503    {
504        if crate::simd_feature_detected!("avx2") {
505            unsafe { linear_predict_avx2(x, weights, output) };
506            return;
507        } else if crate::simd_feature_detected!("sse2") {
508            unsafe { linear_predict_sse2(x, weights, output) };
509            return;
510        }
511    }
512
513    linear_predict_scalar(x, weights, output);
514}
515
/// Scalar prediction kernel: `output[i]` becomes the dot product of row `i`
/// of `x` with `weights`.
fn linear_predict_scalar(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    for (sample_idx, row) in x.iter().enumerate() {
        let mut acc = 0.0;
        // Accumulate feature by feature, in index order.
        for (feature_idx, &weight) in weights.iter().enumerate() {
            acc += row[feature_idx] * weight;
        }
        output[sample_idx] = acc;
    }
}
528
/// SSE2 prediction kernel: `output[i]` becomes the dot product of row `i` of
/// `x` with `weights`.
///
/// Features are processed four at a time; the remainder is handled by a
/// scalar tail loop.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
///
/// # Panics
///
/// Panics if any row of `x` has fewer elements than `weights`, or if `output`
/// is shorter than `x`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn linear_predict_sse2(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_features = weights.len();

    for (i, sample) in x.iter().enumerate() {
        // Bounds-check the whole row once. Without this, a row shorter than
        // `weights` would let the 4-wide loads below read past its end.
        let row = &sample[..n_features];

        let mut sum = _mm_setzero_ps();
        let mut j = 0;

        while j + 4 <= n_features {
            // Both pointers are valid for four elements: j + 4 <= n_features
            // and both slices are exactly n_features long.
            let x_vec = _mm_loadu_ps(row.as_ptr().add(j));
            let w_vec = _mm_loadu_ps(weights.as_ptr().add(j));
            sum = _mm_add_ps(sum, _mm_mul_ps(x_vec, w_vec));
            j += 4;
        }

        // Horizontal reduction of the four partial sums.
        let mut result = [0.0f32; 4];
        _mm_storeu_ps(result.as_mut_ptr(), sum);
        let mut scalar_sum = result[0] + result[1] + result[2] + result[3];

        // Scalar tail for the features that did not fill a full vector.
        while j < n_features {
            scalar_sum += row[j] * weights[j];
            j += 1;
        }

        output[i] = scalar_sum;
    }
}
561
/// AVX2 prediction kernel: `output[i]` becomes the dot product of row `i` of
/// `x` with `weights`.
///
/// Features are processed eight at a time; the remainder is handled by a
/// scalar tail loop.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if any row of `x` has fewer elements than `weights`, or if `output`
/// is shorter than `x`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn linear_predict_avx2(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    // The function is compiled for both x86 and x86_64, so the intrinsics must
    // come from the module that exists on the current target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_features = weights.len();

    for (i, sample) in x.iter().enumerate() {
        // Bounds-check the whole row once. Without this, a row shorter than
        // `weights` would let the 8-wide loads below read past its end.
        let row = &sample[..n_features];

        let mut sum = _mm256_setzero_ps();
        let mut j = 0;

        while j + 8 <= n_features {
            // Both pointers are valid for eight elements: j + 8 <= n_features
            // and both slices are exactly n_features long.
            let x_vec = _mm256_loadu_ps(row.as_ptr().add(j));
            let w_vec = _mm256_loadu_ps(weights.as_ptr().add(j));
            sum = _mm256_add_ps(sum, _mm256_mul_ps(x_vec, w_vec));
            j += 8;
        }

        // Horizontal reduction of the eight partial sums.
        let mut result = [0.0f32; 8];
        _mm256_storeu_ps(result.as_mut_ptr(), sum);
        let mut scalar_sum = result.iter().sum::<f32>();

        // Scalar tail for the features that did not fill a full vector.
        while j < n_features {
            scalar_sum += row[j] * weights[j];
            j += 1;
        }

        output[i] = scalar_sum;
    }
}
594
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Rows of the 2x2 design matrix shared by several tests: [[1, 2], [3, 4]].
    const DESIGN: [[f32; 2]; 2] = [[1.0, 2.0], [3.0, 4.0]];

    /// Borrows every row of `rows` as a slice, the layout the public API expects.
    fn as_row_slices(rows: &[[f32; 2]]) -> Vec<&[f32]> {
        rows.iter().map(|row| &row[..]).collect()
    }

    #[test]
    fn test_least_squares_normal_equation() {
        let x = as_row_slices(&DESIGN);
        let y = [5.0f32, 6.0];

        let (xtx, xty) = least_squares_normal_equation(&x, &y);

        // X^T * X = [[1*1 + 3*3, 1*2 + 3*4], [2*1 + 4*3, 2*2 + 4*4]]
        let expected_xtx = [[10.0f32, 14.0], [14.0, 20.0]];
        for (i, expected_row) in expected_xtx.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_relative_eq!(xtx[i][j], want, epsilon = 1e-6);
            }
        }

        // X^T * y = [1*5 + 3*6, 2*5 + 4*6]
        let expected_xty = [23.0f32, 34.0];
        for (i, &want) in expected_xty.iter().enumerate() {
            assert_relative_eq!(xty[i], want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_ridge_regression_normal_equation() {
        let x = as_row_slices(&DESIGN);
        let y = [5.0f32, 6.0];
        let alpha = 1.0;

        let (xtx, _) = ridge_regression_normal_equation(&x, &y, alpha);

        // X^T * X + alpha * I = [[10 + 1, 14], [14, 20 + 1]]
        let expected = [[11.0f32, 14.0], [14.0, 21.0]];
        for (i, expected_row) in expected.iter().enumerate() {
            for (j, &want) in expected_row.iter().enumerate() {
                assert_relative_eq!(xtx[i][j], want, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_elastic_net_penalty() {
        let weights = vec![1.0, -2.0, 3.0, -4.0];
        let alpha = 0.1;
        let l1_ratio = 0.5;

        let penalty = elastic_net_penalty(&weights, alpha, l1_ratio);

        // L1 norm         = 1 + 2 + 3 + 4  = 10
        // L2 norm squared = 1 + 4 + 9 + 16 = 30
        // Penalty = 0.1 * 0.5 * 10 + 0.5 * 0.1 * 0.5 * 30 = 0.5 + 0.75
        assert_relative_eq!(penalty, 1.25, epsilon = 1e-6);
    }

    #[test]
    fn test_soft_threshold() {
        let values = vec![3.0, -2.0, 1.0, -0.5, 0.0];
        let threshold = 1.5;
        let mut output = vec![0.0; 5];

        soft_threshold(&values, threshold, &mut output);

        //  3.0 -> magnitude above 1.5:  1.0 * (3.0 - 1.5) =  1.5
        // -2.0 -> magnitude above 1.5: -1.0 * (2.0 - 1.5) = -0.5
        //  1.0, -0.5, 0.0 -> magnitude at most 1.5, clamped to 0.0
        let expected = [1.5f32, -0.5, 0.0, 0.0, 0.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(output[i], want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_linear_predict() {
        let x = as_row_slices(&DESIGN);
        let weights = vec![0.5, 1.0];
        let mut output = vec![0.0; 2];

        linear_predict(&x, &weights, &mut output);

        // Sample 1: 1.0 * 0.5 + 2.0 * 1.0 = 2.5
        // Sample 2: 3.0 * 0.5 + 4.0 * 1.0 = 5.5
        let expected = [2.5f32, 5.5];
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(output[i], want, epsilon = 1e-6);
        }
    }
}