1#[cfg(feature = "no-std")]
7use alloc::{vec, vec::Vec};
8
9pub fn least_squares_normal_equation(
12 x: &[&[f32]], y: &[f32], ) -> (Vec<Vec<f32>>, Vec<f32>) {
15 let n_samples = x.len();
16 let n_features = if n_samples > 0 { x[0].len() } else { 0 };
17
18 assert!(!x.is_empty(), "Design matrix cannot be empty");
19 assert_eq!(
20 y.len(),
21 n_samples,
22 "Target length must match number of samples"
23 );
24
25 let mut xtx = vec![vec![0.0f32; n_features]; n_features];
27 let mut xty = vec![0.0f32; n_features];
29
30 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
31 {
32 if crate::simd_feature_detected!("avx2") {
33 unsafe { least_squares_avx2(x, y, &mut xtx, &mut xty) };
34 return (xtx, xty);
35 } else if crate::simd_feature_detected!("sse2") {
36 unsafe { least_squares_sse2(x, y, &mut xtx, &mut xty) };
37 return (xtx, xty);
38 }
39 }
40
41 least_squares_scalar(x, y, &mut xtx, &mut xty);
42 (xtx, xty)
43}
44
/// Scalar kernel for [`least_squares_normal_equation`].
///
/// Writes `XᵀX` into `xtx` and `Xᵀy` into `xty`, where `x` is row-major
/// (`x[k]` is sample `k`). Only the upper triangle of the symmetric Gram matrix
/// is computed. Each entry is then mirrored below the diagonal, which gives the
/// same bits as recomputing it because IEEE multiplication is commutative.
///
/// Panics (index out of bounds) if `x` is empty, a row is shorter than `x[0]`,
/// `y` is shorter than `x`, or the output buffers are smaller than `x[0].len()`.
fn least_squares_scalar(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    let n_features = x[0].len();

    for i in 0..n_features {
        for j in i..n_features {
            let dot: f32 = x.iter().map(|row| row[i] * row[j]).sum();
            xtx[i][j] = dot;
            xtx[j][i] = dot;
        }
    }

    for (i, slot) in xty[..n_features].iter_mut().enumerate() {
        // Indexing `y[k]` (rather than zipping) keeps a short `y` a loud panic.
        *slot = x
            .iter()
            .enumerate()
            .map(|(k, row)| row[i] * y[k])
            .sum();
    }
}
66
/// SSE2 kernel for [`least_squares_normal_equation`]: writes `XᵀX` into `xtx`
/// and `Xᵀy` into `xty`.
///
/// `x` is row-major (`x[k]` is sample `k`), so each 4-lane vector is gathered
/// from one column across four consecutive samples. The `n_samples % 4`
/// leftover samples are added in scalar code. Only the upper triangle of the
/// symmetric Gram matrix is computed and then mirrored. Lane-wise multiplication
/// is commutative, so the mirrored entry is bit-identical to recomputing it.
///
/// # Safety
///
/// The CPU must support SSE2. All slice accesses are bounds-checked, so
/// malformed shapes (empty `x`, short rows, `y` shorter than `x`, undersized
/// output buffers) panic rather than touch memory out of bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn least_squares_sse2(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_samples = x.len();
    let n_features = x[0].len();

    for i in 0..n_features {
        for j in i..n_features {
            let mut acc = _mm_setzero_ps();
            let mut k = 0;

            while k + 4 <= n_samples {
                let col_i = _mm_setr_ps(x[k][i], x[k + 1][i], x[k + 2][i], x[k + 3][i]);
                let col_j = _mm_setr_ps(x[k][j], x[k + 1][j], x[k + 2][j], x[k + 3][j]);
                acc = _mm_add_ps(acc, _mm_mul_ps(col_i, col_j));
                k += 4;
            }

            let mut partial = [0.0f32; 4];
            _mm_storeu_ps(partial.as_mut_ptr(), acc);
            let mut sum = partial[0] + partial[1] + partial[2] + partial[3];
            for row in &x[k..] {
                sum += row[i] * row[j];
            }

            xtx[i][j] = sum;
            xtx[j][i] = sum;
        }
    }

    for i in 0..n_features {
        let mut acc = _mm_setzero_ps();
        let mut k = 0;

        while k + 4 <= n_samples {
            let col_i = _mm_setr_ps(x[k][i], x[k + 1][i], x[k + 2][i], x[k + 3][i]);
            // Slicing bounds-checks all four targets before the raw unaligned load.
            let targets = _mm_loadu_ps(y[k..k + 4].as_ptr());
            acc = _mm_add_ps(acc, _mm_mul_ps(col_i, targets));
            k += 4;
        }

        let mut partial = [0.0f32; 4];
        _mm_storeu_ps(partial.as_mut_ptr(), acc);
        let mut sum = partial[0] + partial[1] + partial[2] + partial[3];
        while k < n_samples {
            sum += x[k][i] * y[k];
            k += 1;
        }

        xty[i] = sum;
    }
}
127
/// AVX2 kernel for [`least_squares_normal_equation`]: writes `XᵀX` into `xtx`
/// and `Xᵀy` into `xty`.
///
/// `x` is row-major (`x[k]` is sample `k`), so each 8-lane vector is gathered
/// column-wise from eight consecutive samples into a stack buffer and loaded
/// from there. The `n_samples % 8` leftover samples are added in scalar code.
/// Only the upper triangle of the symmetric Gram matrix is computed and then
/// mirrored. Lane-wise multiplication is commutative, so the mirrored entry is
/// bit-identical to recomputing it.
///
/// # Safety
///
/// The CPU must support AVX2. All slice accesses are bounds-checked, so
/// malformed shapes (empty `x`, short rows, `y` shorter than `x`, undersized
/// output buffers) panic rather than touch memory out of bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn least_squares_avx2(x: &[&[f32]], y: &[f32], xtx: &mut [Vec<f32>], xty: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let n_samples = x.len();
    let n_features = x[0].len();

    for i in 0..n_features {
        for j in i..n_features {
            let mut acc = _mm256_setzero_ps();
            let mut k = 0;

            while k + LANES <= n_samples {
                // Lane `l` holds sample `k + l`, matching `_mm256_setr_ps` order.
                let mut col_i = [0.0f32; LANES];
                let mut col_j = [0.0f32; LANES];
                for (lane, row) in x[k..k + LANES].iter().enumerate() {
                    col_i[lane] = row[i];
                    col_j[lane] = row[j];
                }
                let prod = _mm256_mul_ps(
                    _mm256_loadu_ps(col_i.as_ptr()),
                    _mm256_loadu_ps(col_j.as_ptr()),
                );
                acc = _mm256_add_ps(acc, prod);
                k += LANES;
            }

            let mut partial = [0.0f32; LANES];
            _mm256_storeu_ps(partial.as_mut_ptr(), acc);
            let mut sum = partial.iter().sum::<f32>();
            for row in &x[k..] {
                sum += row[i] * row[j];
            }

            xtx[i][j] = sum;
            xtx[j][i] = sum;
        }
    }

    for i in 0..n_features {
        let mut acc = _mm256_setzero_ps();
        let mut k = 0;

        while k + LANES <= n_samples {
            let mut col_i = [0.0f32; LANES];
            for (lane, row) in x[k..k + LANES].iter().enumerate() {
                col_i[lane] = row[i];
            }
            // Slicing bounds-checks all eight targets before the raw unaligned load.
            let targets = _mm256_loadu_ps(y[k..k + LANES].as_ptr());
            acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_loadu_ps(col_i.as_ptr()), targets));
            k += LANES;
        }

        let mut partial = [0.0f32; LANES];
        _mm256_storeu_ps(partial.as_mut_ptr(), acc);
        let mut sum = partial.iter().sum::<f32>();
        while k < n_samples {
            sum += x[k][i] * y[k];
            k += 1;
        }

        xty[i] = sum;
    }
}
215
216pub fn ridge_regression_normal_equation(
219 x: &[&[f32]], y: &[f32], alpha: f32, ) -> (Vec<Vec<f32>>, Vec<f32>) {
223 let (mut xtx, xty) = least_squares_normal_equation(x, y);
224
225 for (i, row) in xtx.iter_mut().enumerate() {
227 row[i] += alpha;
228 }
229
230 (xtx, xty)
231}
232
233pub fn elastic_net_penalty(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
236 assert!(
237 (0.0..=1.0).contains(&l1_ratio),
238 "l1_ratio must be between 0 and 1"
239 );
240 assert!(alpha >= 0.0, "alpha must be non-negative");
241
242 if weights.is_empty() {
243 return 0.0;
244 }
245
246 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
247 {
248 if crate::simd_feature_detected!("avx2") {
249 return unsafe { elastic_net_penalty_avx2(weights, alpha, l1_ratio) };
250 } else if crate::simd_feature_detected!("sse2") {
251 return unsafe { elastic_net_penalty_sse2(weights, alpha, l1_ratio) };
252 }
253 }
254
255 elastic_net_penalty_scalar(weights, alpha, l1_ratio)
256}
257
/// Scalar kernel for [`elastic_net_penalty`].
///
/// Sums `|w|` and `w²` in slice order, then combines them as
/// `alpha * l1_ratio * l1 + 0.5 * alpha * (1 - l1_ratio) * l2`.
fn elastic_net_penalty_scalar(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    let abs_total = weights.iter().copied().map(f32::abs).sum::<f32>();
    let sq_total = weights.iter().map(|&w| w * w).sum::<f32>();

    let l1_term = alpha * l1_ratio * abs_total;
    let l2_term = 0.5 * alpha * (1.0 - l1_ratio) * sq_total;
    l1_term + l2_term
}
264
/// SSE2 kernel for [`elastic_net_penalty`].
///
/// Accumulates `|w|` (clearing the sign bit with `andnot`) and `w²` four lanes
/// at a time. It then folds the lanes and finishes the `len % 4` tail in
/// scalar code before applying the elastic-net formula.
///
/// # Safety
///
/// The CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn elastic_net_penalty_sse2(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let sign_mask = _mm_set1_ps(-0.0f32);
    let mut l1_acc = _mm_setzero_ps();
    let mut l2_acc = _mm_setzero_ps();

    let chunks = weights.chunks_exact(4);
    let tail = chunks.remainder();
    for chunk in chunks {
        // `chunk` has exactly four elements, so the unaligned load stays in bounds.
        let w = _mm_loadu_ps(chunk.as_ptr());
        l1_acc = _mm_add_ps(l1_acc, _mm_andnot_ps(sign_mask, w));
        l2_acc = _mm_add_ps(l2_acc, _mm_mul_ps(w, w));
    }

    let mut l1_lanes = [0.0f32; 4];
    let mut l2_lanes = [0.0f32; 4];
    _mm_storeu_ps(l1_lanes.as_mut_ptr(), l1_acc);
    _mm_storeu_ps(l2_lanes.as_mut_ptr(), l2_acc);

    let mut l1_norm = l1_lanes[0] + l1_lanes[1] + l1_lanes[2] + l1_lanes[3];
    let mut l2_norm_sq = l2_lanes[0] + l2_lanes[1] + l2_lanes[2] + l2_lanes[3];
    for &w in tail {
        l1_norm += w.abs();
        l2_norm_sq += w * w;
    }

    alpha * l1_ratio * l1_norm + 0.5 * alpha * (1.0 - l1_ratio) * l2_norm_sq
}
305
/// AVX2 kernel for [`elastic_net_penalty`].
///
/// Accumulates `|w|` (clearing the sign bit with `andnot`) and `w²` eight lanes
/// at a time. It then folds the lanes and finishes the `len % 8` tail in
/// scalar code before applying the elastic-net formula.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn elastic_net_penalty_avx2(weights: &[f32], alpha: f32, l1_ratio: f32) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let sign_mask = _mm256_set1_ps(-0.0f32);
    let mut l1_acc = _mm256_setzero_ps();
    let mut l2_acc = _mm256_setzero_ps();

    let chunks = weights.chunks_exact(8);
    let tail = chunks.remainder();
    for chunk in chunks {
        // `chunk` has exactly eight elements, so the unaligned load stays in bounds.
        let w = _mm256_loadu_ps(chunk.as_ptr());
        l1_acc = _mm256_add_ps(l1_acc, _mm256_andnot_ps(sign_mask, w));
        l2_acc = _mm256_add_ps(l2_acc, _mm256_mul_ps(w, w));
    }

    let mut l1_lanes = [0.0f32; 8];
    let mut l2_lanes = [0.0f32; 8];
    _mm256_storeu_ps(l1_lanes.as_mut_ptr(), l1_acc);
    _mm256_storeu_ps(l2_lanes.as_mut_ptr(), l2_acc);

    let mut l1_norm = l1_lanes.iter().sum::<f32>();
    let mut l2_norm_sq = l2_lanes.iter().sum::<f32>();
    for &w in tail {
        l1_norm += w.abs();
        l2_norm_sq += w * w;
    }

    alpha * l1_ratio * l1_norm + 0.5 * alpha * (1.0 - l1_ratio) * l2_norm_sq
}
346
347pub fn soft_threshold(values: &[f32], threshold: f32, output: &mut [f32]) {
350 assert_eq!(
351 values.len(),
352 output.len(),
353 "Arrays must have the same length"
354 );
355 assert!(threshold >= 0.0, "Threshold must be non-negative");
356
357 if values.is_empty() {
358 return;
359 }
360
361 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
362 {
363 if crate::simd_feature_detected!("avx2") {
364 unsafe { soft_threshold_avx2(values, threshold, output) };
365 return;
366 } else if crate::simd_feature_detected!("sse2") {
367 unsafe { soft_threshold_sse2(values, threshold, output) };
368 return;
369 }
370 }
371
372 soft_threshold_scalar(values, threshold, output);
373}
374
/// Scalar kernel for [`soft_threshold`].
///
/// Writes `0.0` where `|v| <= threshold` and `v.signum() * (|v| - threshold)`
/// otherwise. Because the zero branch is taken only when `<=` holds, a NaN
/// input produces NaN. Panics (index out of bounds) if `output` is shorter
/// than `values`.
fn soft_threshold_scalar(values: &[f32], threshold: f32, output: &mut [f32]) {
    for (idx, &v) in values.iter().enumerate() {
        let magnitude = v.abs();
        output[idx] = if magnitude <= threshold {
            0.0
        } else {
            v.signum() * (magnitude - threshold)
        };
    }
}
385
/// SSE2 kernel for [`soft_threshold`].
///
/// Processes four lanes at a time. The shrunk magnitude `|v| - threshold` gets
/// `v`'s sign bit re-attached (copysign), and lanes are zeroed where
/// `|v| <= threshold`. The same "not `<=`" predicate as the scalar tail is used,
/// so NaN lanes come out as NaN instead of 0. Vector and scalar paths therefore
/// agree on every input.
///
/// # Safety
///
/// The CPU must support SSE2. Loads and stores go through bounds-checked
/// subslices, so `output` shorter than `values` panics instead of writing out
/// of bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn soft_threshold_sse2(values: &[f32], threshold: f32, output: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let threshold_vec = _mm_set1_ps(threshold);
    let sign_mask = _mm_set1_ps(-0.0f32);
    let mut i = 0;

    while i + 4 <= values.len() {
        // Slicing bounds-checks both buffers before the raw unaligned load/store.
        let src = &values[i..i + 4];
        let dst = &mut output[i..i + 4];

        let v = _mm_loadu_ps(src.as_ptr());
        let sign_bits = _mm_and_ps(sign_mask, v);
        let magnitude = _mm_andnot_ps(sign_mask, v);

        // `cmpnle` is "not (|v| <= t)": true for |v| > t and for NaN lanes.
        let keep = _mm_cmpnle_ps(magnitude, threshold_vec);
        // copysign(|v| - t, v): re-attach the input's sign to the shrunk magnitude.
        let shrunk = _mm_or_ps(_mm_sub_ps(magnitude, threshold_vec), sign_bits);

        _mm_storeu_ps(dst.as_mut_ptr(), _mm_and_ps(keep, shrunk));
        i += 4;
    }

    while i < values.len() {
        let abs_val = values[i].abs();
        output[i] = if abs_val <= threshold {
            0.0
        } else {
            values[i].signum() * (abs_val - threshold)
        };
        i += 1;
    }
}
431
/// AVX2 kernel for [`soft_threshold`].
///
/// Processes eight lanes at a time. The shrunk magnitude `|v| - threshold` gets
/// `v`'s sign bit re-attached (copysign), and lanes are zeroed where
/// `|v| <= threshold`. The unordered `_CMP_NLE_UQ` predicate is used, so NaN
/// lanes come out as NaN, matching the scalar tail instead of becoming 0.
///
/// # Safety
///
/// The CPU must support AVX2. Loads and stores go through bounds-checked
/// subslices, so `output` shorter than `values` panics instead of writing out
/// of bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn soft_threshold_avx2(values: &[f32], threshold: f32, output: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let threshold_vec = _mm256_set1_ps(threshold);
    let sign_mask = _mm256_set1_ps(-0.0f32);
    let mut i = 0;

    while i + 8 <= values.len() {
        // Slicing bounds-checks both buffers before the raw unaligned load/store.
        let src = &values[i..i + 8];
        let dst = &mut output[i..i + 8];

        let v = _mm256_loadu_ps(src.as_ptr());
        let sign_bits = _mm256_and_ps(sign_mask, v);
        let magnitude = _mm256_andnot_ps(sign_mask, v);

        // "Not (|v| <= t)", unordered: true for |v| > t and for NaN lanes.
        let keep = _mm256_cmp_ps(magnitude, threshold_vec, _CMP_NLE_UQ);
        // copysign(|v| - t, v): re-attach the input's sign to the shrunk magnitude.
        let shrunk = _mm256_or_ps(_mm256_sub_ps(magnitude, threshold_vec), sign_bits);

        _mm256_storeu_ps(dst.as_mut_ptr(), _mm256_and_ps(keep, shrunk));
        i += 8;
    }

    while i < values.len() {
        let abs_val = values[i].abs();
        output[i] = if abs_val <= threshold {
            0.0
        } else {
            values[i].signum() * (abs_val - threshold)
        };
        i += 1;
    }
}
480
481pub fn linear_predict(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
484 let n_samples = x.len();
485 let n_features = if n_samples > 0 { x[0].len() } else { 0 };
486
487 assert_eq!(
488 weights.len(),
489 n_features,
490 "Weight length must match number of features"
491 );
492 assert_eq!(
493 output.len(),
494 n_samples,
495 "Output length must match number of samples"
496 );
497
498 if n_samples == 0 {
499 return;
500 }
501
502 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
503 {
504 if crate::simd_feature_detected!("avx2") {
505 unsafe { linear_predict_avx2(x, weights, output) };
506 return;
507 } else if crate::simd_feature_detected!("sse2") {
508 unsafe { linear_predict_sse2(x, weights, output) };
509 return;
510 }
511 }
512
513 linear_predict_scalar(x, weights, output);
514}
515
/// Scalar kernel for [`linear_predict`]: `output[i] = Σ_j x[i][j] * weights[j]`.
///
/// Each dot product is accumulated left to right starting from `0.0`. Panics
/// (index out of bounds) if a row is shorter than `weights` or `output` is
/// shorter than `x`.
fn linear_predict_scalar(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    for (i, row) in x.iter().enumerate() {
        let mut dot = 0.0;
        for (j, &w) in weights.iter().enumerate() {
            dot += row[j] * w;
        }
        output[i] = dot;
    }
}
528
/// SSE2 kernel for [`linear_predict`]: `output[i] = x[i] · weights`.
///
/// Each row is multiplied against `weights` four lanes at a time. The lanes are
/// then folded, and the `len % 4` tail is added in scalar code.
///
/// # Safety
///
/// The CPU must support SSE2. Each row is re-sliced to `weights.len()` before
/// any raw load, so a short row panics instead of being read past its end.
/// `output` shorter than `x` also panics.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn linear_predict_sse2(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_features = weights.len();

    for (i, full_row) in x.iter().enumerate() {
        // The previous `&x[i][j]` pointer only bounds-checked one element before
        // a 4-wide load; slicing to the exact width makes every chunk in-bounds.
        let row = &full_row[..n_features];

        let row_chunks = row.chunks_exact(4);
        let weight_chunks = weights.chunks_exact(4);
        let row_tail = row_chunks.remainder();
        let weight_tail = weight_chunks.remainder();

        let mut acc = _mm_setzero_ps();
        for (xs, ws) in row_chunks.zip(weight_chunks) {
            let prod = _mm_mul_ps(_mm_loadu_ps(xs.as_ptr()), _mm_loadu_ps(ws.as_ptr()));
            acc = _mm_add_ps(acc, prod);
        }

        let mut partial = [0.0f32; 4];
        _mm_storeu_ps(partial.as_mut_ptr(), acc);
        let mut sum = partial[0] + partial[1] + partial[2] + partial[3];
        for (&xv, &wv) in row_tail.iter().zip(weight_tail) {
            sum += xv * wv;
        }

        output[i] = sum;
    }
}
561
/// AVX2 kernel for [`linear_predict`]: `output[i] = x[i] · weights`.
///
/// Each row is multiplied against `weights` eight lanes at a time. The lanes
/// are then folded, and the `len % 8` tail is added in scalar code.
///
/// # Safety
///
/// The CPU must support AVX2. Each row is re-sliced to `weights.len()` before
/// any raw load, so a short row panics instead of being read past its end.
/// `output` shorter than `x` also panics.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn linear_predict_avx2(x: &[&[f32]], weights: &[f32], output: &mut [f32]) {
    // `core::arch::x86_64` does not exist on 32-bit x86; import the module that
    // matches the target so this kernel builds under both arms of the `cfg`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n_features = weights.len();

    for (i, full_row) in x.iter().enumerate() {
        // The previous `&x[i][j]` pointer only bounds-checked one element before
        // an 8-wide load; slicing to the exact width makes every chunk in-bounds.
        let row = &full_row[..n_features];

        let row_chunks = row.chunks_exact(8);
        let weight_chunks = weights.chunks_exact(8);
        let row_tail = row_chunks.remainder();
        let weight_tail = weight_chunks.remainder();

        let mut acc = _mm256_setzero_ps();
        for (xs, ws) in row_chunks.zip(weight_chunks) {
            let prod = _mm256_mul_ps(_mm256_loadu_ps(xs.as_ptr()), _mm256_loadu_ps(ws.as_ptr()));
            acc = _mm256_add_ps(acc, prod);
        }

        let mut partial = [0.0f32; 8];
        _mm256_storeu_ps(partial.as_mut_ptr(), acc);
        let mut sum = partial.iter().sum::<f32>();
        for (&xv, &wv) in row_tail.iter().zip(weight_tail) {
            sum += xv * wv;
        }

        output[i] = sum;
    }
}
594
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Deterministic `rows × cols` matrix of small integers in `[-5, 5]`.
    ///
    /// Every product and partial sum in the size-sweep tests stays an integer or
    /// half-integer far below 2^24. The `f32` results are therefore exact
    /// whatever order a SIMD kernel sums in, which lets the sweeps use `assert_eq!`.
    fn int_matrix(rows: usize, cols: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|r| {
                (0..cols)
                    .map(|c| ((r * 7 + c * 3) % 11) as f32 - 5.0)
                    .collect()
            })
            .collect()
    }

    #[test]
    fn test_least_squares_normal_equation() {
        let x1 = [1.0, 2.0];
        let x2 = [3.0, 4.0];
        let x = vec![&x1[..], &x2[..]];
        let y = [5.0, 6.0];

        let (xtx, xty) = least_squares_normal_equation(&x, &y);

        assert_relative_eq!(xtx[0][0], 10.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[0][1], 14.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[1][0], 14.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[1][1], 20.0, epsilon = 1e-6);

        assert_relative_eq!(xty[0], 23.0, epsilon = 1e-6);
        assert_relative_eq!(xty[1], 34.0, epsilon = 1e-6);
    }

    #[test]
    fn test_ridge_regression_normal_equation() {
        let x1 = [1.0, 2.0];
        let x2 = [3.0, 4.0];
        let x = vec![&x1[..], &x2[..]];
        let y = [5.0, 6.0];
        let alpha = 1.0;

        let (xtx, _) = ridge_regression_normal_equation(&x, &y, alpha);

        assert_relative_eq!(xtx[0][0], 11.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[0][1], 14.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[1][0], 14.0, epsilon = 1e-6);
        assert_relative_eq!(xtx[1][1], 21.0, epsilon = 1e-6);
    }

    #[test]
    fn test_elastic_net_penalty() {
        let weights = vec![1.0, -2.0, 3.0, -4.0];
        let alpha = 0.1;
        let l1_ratio = 0.5;

        let penalty = elastic_net_penalty(&weights, alpha, l1_ratio);

        assert_relative_eq!(penalty, 1.25, epsilon = 1e-6);
    }

    #[test]
    fn test_soft_threshold() {
        let values = vec![3.0, -2.0, 1.0, -0.5, 0.0];
        let threshold = 1.5;
        let mut output = vec![0.0; 5];

        soft_threshold(&values, threshold, &mut output);

        assert_relative_eq!(output[0], 1.5, epsilon = 1e-6);
        assert_relative_eq!(output[1], -0.5, epsilon = 1e-6);
        assert_relative_eq!(output[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 0.0, epsilon = 1e-6);
        assert_relative_eq!(output[4], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_linear_predict() {
        let x1 = [1.0, 2.0];
        let x2 = [3.0, 4.0];
        let x = vec![&x1[..], &x2[..]];
        let weights = vec![0.5, 1.0];
        let mut output = vec![0.0; 2];

        linear_predict(&x, &weights, &mut output);

        assert_relative_eq!(output[0], 2.5, epsilon = 1e-6);
        assert_relative_eq!(output[1], 5.5, epsilon = 1e-6);
    }

    /// 1..=17 samples cover: no SIMD chunk, SSE2 chunks + tail, AVX2 chunks + tail.
    #[test]
    fn test_least_squares_sizes_past_simd_width() {
        for rows in 1..=17usize {
            for &cols in &[1usize, 3, 5] {
                let data = int_matrix(rows, cols);
                let x: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();
                let y: Vec<f32> = (0..rows).map(|k| (k % 5) as f32 - 2.0).collect();

                let (xtx, xty) = least_squares_normal_equation(&x, &y);

                for i in 0..cols {
                    for j in 0..cols {
                        let expected: f32 = data.iter().map(|row| row[i] * row[j]).sum();
                        assert_eq!(xtx[i][j], expected, "xtx[{}][{}] rows={}", i, j, rows);
                    }
                    let expected: f32 = data.iter().zip(&y).map(|(row, t)| row[i] * t).sum();
                    assert_eq!(xty[i], expected, "xty[{}] rows={}", i, rows);
                }
            }
        }
    }

    /// 1..=19 features cover the SSE2 (4) and AVX2 (8) main loops and their tails.
    #[test]
    fn test_linear_predict_sizes_past_simd_width() {
        for cols in 1..=19usize {
            let data = int_matrix(3, cols);
            let x: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();
            let weights: Vec<f32> = (0..cols).map(|j| (j % 4) as f32 - 1.5).collect();
            let mut output = vec![0.0f32; 3];

            linear_predict(&x, &weights, &mut output);

            for (row, &got) in data.iter().zip(&output) {
                let expected: f32 = row.iter().zip(&weights).map(|(a, b)| a * b).sum();
                assert_eq!(got, expected, "cols={}", cols);
            }
        }
    }

    #[test]
    fn test_soft_threshold_sizes_past_simd_width() {
        for len in 1..=19usize {
            let values: Vec<f32> = (0..len).map(|i| i as f32 - 9.0).collect();
            let mut output = vec![f32::NAN; len];

            soft_threshold(&values, 2.5, &mut output);

            for (&v, &got) in values.iter().zip(&output) {
                let expected = if v.abs() <= 2.5 {
                    0.0
                } else {
                    v.signum() * (v.abs() - 2.5)
                };
                assert_eq!(got, expected, "v={} len={}", v, len);
            }
        }
    }

    #[test]
    fn test_elastic_net_penalty_sizes_past_simd_width() {
        for len in 1..=19usize {
            // Alternating-sign integers 1, -2, 3, ... keep both norms exact in f32.
            let weights: Vec<f32> = (1..=len)
                .map(|k| if k % 2 == 0 { -(k as f32) } else { k as f32 })
                .collect();
            let l1: f32 = weights.iter().map(|w| w.abs()).sum();
            let l2: f32 = weights.iter().map(|w| w * w).sum();

            let penalty = elastic_net_penalty(&weights, 1.0, 0.5);

            assert_eq!(penalty, 0.5 * l1 + 0.25 * l2, "len={}", len);
        }
    }
}