//! Linear-model and stacking-ensemble helpers written as straight-line
//! iterator/loop code so the compiler can auto-vectorize them; no explicit
//! SIMD intrinsics are used.

// The design matrix is conventionally named `X`, which is not snake_case.
#![allow(non_snake_case)]

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Dot product of two equal-length vectors.
///
/// Written as a plain iterator chain so the compiler can auto-vectorize it.
/// Mismatched lengths yield 0.0 rather than a panic.
pub fn simd_dot_product(x: &ArrayView1<Float>, weights: &ArrayView1<Float>) -> Float {
    if x.len() != weights.len() {
        return 0.0;
    }
    x.iter().zip(weights.iter()).map(|(&xi, &wi)| xi * wi).sum()
}

/// Linear prediction for a single sample: `x · weights + intercept`.
pub fn simd_linear_prediction(
    x: &ArrayView1<Float>,
    weights: &ArrayView1<Float>,
    intercept: Float,
) -> Float {
    simd_dot_product(x, weights) + intercept
}

/// Linear predictions for every row of `X`.
///
/// Returns a `FeatureMismatch` error if `weights` does not match the number
/// of columns in `X`.
pub fn simd_batch_linear_predictions(
    X: &ArrayView2<Float>,
    weights: &ArrayView1<Float>,
    intercept: Float,
) -> Result<Array1<Float>> {
    let (n_samples, n_features) = X.dim();

    if weights.len() != n_features {
        return Err(SklearsError::FeatureMismatch {
            expected: n_features,
            actual: weights.len(),
        });
    }

    let mut predictions = Array1::<Float>::zeros(n_samples);

    for i in 0..n_samples {
        let x_sample = X.row(i);
        predictions[i] = simd_linear_prediction(&x_sample, weights, intercept);
    }

    Ok(predictions)
}

/// Builds the meta-feature matrix for a stacking ensemble.
///
/// Each column holds the linear predictions of one base estimator, so the
/// result has shape `(n_samples, n_estimators)`.
pub fn simd_generate_meta_features(
    X: &ArrayView2<Float>,
    base_weights: &ArrayView2<Float>,
    base_intercepts: &ArrayView1<Float>,
) -> Result<Array2<Float>> {
    let (n_samples, n_features) = X.dim();
    let (n_estimators, weight_features) = base_weights.dim();

    if weight_features != n_features {
        return Err(SklearsError::FeatureMismatch {
            expected: n_features,
            actual: weight_features,
        });
    }

    if base_intercepts.len() != n_estimators {
        return Err(SklearsError::InvalidInput(
            "Number of intercepts must match number of estimators".to_string(),
        ));
    }

    let mut meta_features = Array2::<Float>::zeros((n_samples, n_estimators));

    for est_idx in 0..n_estimators {
        let weights = base_weights.row(est_idx);
        let intercept = base_intercepts[est_idx];

        for sample_idx in 0..n_samples {
            let x_sample = X.row(sample_idx);
            meta_features[[sample_idx, est_idx]] =
                simd_linear_prediction(&x_sample, &weights, intercept);
        }
    }

    Ok(meta_features)
}

/// Gradients of the squared-error loss for a linear model, averaged over the
/// samples, with an L2 penalty of `l2_reg * weight` added to each weight
/// gradient (the intercept is not regularized).
pub fn simd_compute_gradients(
    X: &ArrayView2<Float>,
    y: &ArrayView1<Float>,
    weights: &ArrayView1<Float>,
    intercept: Float,
    l2_reg: Float,
) -> Result<(Array1<Float>, Float)> {
    let (n_samples, n_features) = X.dim();

    if y.len() != n_samples {
        return Err(SklearsError::ShapeMismatch {
            expected: format!("{} samples", n_samples),
            actual: format!("{} samples", y.len()),
        });
    }

    if weights.len() != n_features {
        return Err(SklearsError::FeatureMismatch {
            expected: n_features,
            actual: weights.len(),
        });
    }

    let mut grad_weights = Array1::<Float>::zeros(n_features);
    let mut grad_intercept = 0.0;

    for i in 0..n_samples {
        let x_i = X.row(i);
        let y_i = y[i];

        let pred = simd_linear_prediction(&x_i, weights, intercept);
        let error = pred - y_i;

        grad_intercept += error;

        for j in 0..n_features {
            grad_weights[j] += error * x_i[j];
        }
    }

    // Average over samples and add the L2 term to the weight gradients.
    let n_samples_f = n_samples as Float;
    grad_intercept /= n_samples_f;

    for i in 0..n_features {
        grad_weights[i] = grad_weights[i] / n_samples_f + l2_reg * weights[i];
    }

    Ok((grad_weights, grad_intercept))
}

/// Combines base-estimator predictions with the meta-learner's weights and
/// intercept to produce the final ensemble predictions.
pub fn simd_aggregate_predictions(
    base_predictions: &ArrayView2<Float>,
    meta_weights: &ArrayView1<Float>,
    meta_intercept: Float,
) -> Result<Array1<Float>> {
    let (n_samples, n_estimators) = base_predictions.dim();

    if meta_weights.len() != n_estimators {
        return Err(SklearsError::FeatureMismatch {
            expected: n_estimators,
            actual: meta_weights.len(),
        });
    }

    let mut final_predictions = Array1::<Float>::zeros(n_samples);

    for i in 0..n_samples {
        let base_preds = base_predictions.row(i);
        final_predictions[i] = simd_dot_product(&base_preds, meta_weights) + meta_intercept;
    }

    Ok(final_predictions)
}

/// Trained stacking-ensemble parameters: one linear base estimator per row of
/// `base_weights`, plus a linear meta-learner over the base predictions.
#[derive(Debug, Clone)]
pub struct StackingEnsembleModel {
    pub base_weights: Array2<Float>,
    pub base_intercepts: Array1<Float>,
    pub meta_weights: Array1<Float>,
    pub meta_intercept: Float,
    pub n_features: usize,
    pub n_estimators: usize,
}

impl StackingEnsembleModel {
    /// Predicts by generating the base-estimator meta-features and then
    /// aggregating them with the meta-learner.
    pub fn predict(&self, X: &ArrayView2<Float>) -> Result<Array1<Float>> {
        let meta_features = simd_generate_meta_features(
            X,
            &self.base_weights.view(),
            &self.base_intercepts.view(),
        )?;

        simd_aggregate_predictions(
            &meta_features.view(),
            &self.meta_weights.view(),
            self.meta_intercept,
        )
    }
}

/// Trains the stacking ensemble's meta-learner by full-batch gradient descent
/// on the meta-features.
///
/// Note: the base estimators are left at their zero initialization here; only
/// the meta-learner's weights and intercept are updated.
pub fn simd_train_stacking_ensemble(
    X: &ArrayView2<Float>,
    y: &ArrayView1<Float>,
    n_base_estimators: usize,
    learning_rate: Float,
    l2_reg: Float,
    n_iterations: usize,
) -> Result<StackingEnsembleModel> {
    let (n_samples, n_features) = X.dim();

    if y.len() != n_samples {
        return Err(SklearsError::ShapeMismatch {
            expected: format!("{} samples", n_samples),
            actual: format!("{} samples", y.len()),
        });
    }

    // Base estimators are zero-initialized and not fitted in this routine.
    let base_weights = Array2::<Float>::zeros((n_base_estimators, n_features));
    let base_intercepts = Array1::<Float>::zeros(n_base_estimators);
    let mut meta_weights = Array1::<Float>::zeros(n_base_estimators);
    let mut meta_intercept = 0.0;

    for _iter in 0..n_iterations {
        let meta_features =
            simd_generate_meta_features(X, &base_weights.view(), &base_intercepts.view())?;

        let (grad_weights, grad_intercept) = simd_compute_gradients(
            &meta_features.view(),
            y,
            &meta_weights.view(),
            meta_intercept,
            l2_reg,
        )?;

        for i in 0..n_base_estimators {
            meta_weights[i] -= learning_rate * grad_weights[i];
        }
        meta_intercept -= learning_rate * grad_intercept;
    }

    Ok(StackingEnsembleModel {
        base_weights,
        base_intercepts,
        meta_weights,
        meta_intercept,
        n_features,
        n_estimators: n_base_estimators,
    })
}

/// Arithmetic mean; returns 0.0 for an empty array.
fn simd_mean(arr: &ArrayView1<Float>) -> Float {
    if arr.is_empty() {
        return 0.0;
    }
    arr.sum() / arr.len() as Float
}

/// Sample variance (Bessel-corrected) around a precomputed mean; returns 0.0
/// for fewer than two elements.
fn simd_variance(arr: &ArrayView1<Float>, mean: Float) -> Float {
    if arr.len() < 2 {
        return 0.0;
    }
    let sum_sq_diff: Float = arr.iter().map(|&x| (x - mean).powi(2)).sum();
    sum_sq_diff / (arr.len() - 1) as Float
}

/// Mean pairwise diversity of the ensemble, where the diversity of a pair of
/// estimators is `1 - |Pearson correlation|` of their prediction columns.
/// Returns 0.0 when there are fewer than two estimators.
pub fn simd_compute_ensemble_diversity(predictions: &ArrayView2<Float>) -> Result<Float> {
    let (_n_samples, n_estimators) = predictions.dim();

    if n_estimators < 2 {
        return Ok(0.0);
    }

    let mut total_diversity = 0.0;
    let mut pair_count = 0;

    for i in 0..n_estimators {
        for j in i + 1..n_estimators {
            let pred_i = predictions.column(i);
            let pred_j = predictions.column(j);

            let correlation = simd_correlation_coefficient(&pred_i, &pred_j);
            let diversity = 1.0 - correlation.abs();
            total_diversity += diversity;
            pair_count += 1;
        }
    }

    Ok(total_diversity / pair_count as Float)
}

/// Pearson correlation coefficient of two equal-length vectors; returns 0.0
/// for mismatched lengths, fewer than two elements, or (near-)constant input.
fn simd_correlation_coefficient(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
    if x.len() != y.len() || x.len() < 2 {
        return 0.0;
    }

    let mean_x = simd_mean(x);
    let mean_y = simd_mean(y);

    let mut sum_xy = 0.0;
    let mut sum_xx = 0.0;
    let mut sum_yy = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;
        sum_xy += dx * dy;
        sum_xx += dx * dx;
        sum_yy += dy * dy;
    }

    let denominator = (sum_xx * sum_yy).sqrt();
    if denominator > 1e-12 {
        sum_xy / denominator
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_simd_dot_product() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let w = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let result = simd_dot_product(&x.view(), &w.view());
        let expected = 1.0 * 0.1 + 2.0 * 0.2 + 3.0 * 0.3 + 4.0 * 0.4;

        assert!((result - expected).abs() < 1e-10);
    }

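    // Additional coverage (a hedged sketch, not part of the original suite):
    // exercises `simd_batch_linear_predictions` on a hand-computed 2x2 example
    // and checks that a weight-length mismatch is reported as an error. Only
    // constructors already used elsewhere in this file are relied on.
    #[test]
    fn test_simd_batch_linear_predictions() {
        // X = [[1, 2], [3, 4]], w = [0.5, -1.0], intercept = 2.0
        let mut X = Array2::<Float>::zeros((2, 2));
        X[[0, 0]] = 1.0;
        X[[0, 1]] = 2.0;
        X[[1, 0]] = 3.0;
        X[[1, 1]] = 4.0;
        let w = Array1::from_vec(vec![0.5, -1.0]);

        let preds = simd_batch_linear_predictions(&X.view(), &w.view(), 2.0).unwrap();
        // Row 0: 1*0.5 + 2*(-1) + 2 = 0.5; row 1: 3*0.5 + 4*(-1) + 2 = -0.5
        assert!((preds[0] - 0.5).abs() < 1e-10);
        assert!((preds[1] + 0.5).abs() < 1e-10);

        // Three weights against two features must be rejected.
        let bad_w = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(simd_batch_linear_predictions(&X.view(), &bad_w.view(), 0.0).is_err());
    }
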
    #[test]
    fn test_simd_linear_prediction() {
        let x = Array1::from_vec(vec![2.0, 3.0]);
        let w = Array1::from_vec(vec![0.5, 0.3]);
        let intercept = 1.5;

        let result = simd_linear_prediction(&x.view(), &w.view(), intercept);
        let expected = 2.0 * 0.5 + 3.0 * 0.3 + 1.5;

        assert!((result - expected).abs() < 1e-10);
    }

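    // Additional coverage (a hedged sketch, not part of the original suite):
    // builds meta-features for two one-hot samples by hand and feeds them
    // through `simd_aggregate_predictions` to check the stacking pipeline.
    #[test]
    fn test_simd_meta_features_and_aggregation() {
        // X = [[1, 0], [0, 1]]; two base estimators with weights
        // [[1, 2], [3, 4]] and intercepts [0.5, -0.5].
        let mut X = Array2::<Float>::zeros((2, 2));
        X[[0, 0]] = 1.0;
        X[[1, 1]] = 1.0;

        let mut base_weights = Array2::<Float>::zeros((2, 2));
        base_weights[[0, 0]] = 1.0;
        base_weights[[0, 1]] = 2.0;
        base_weights[[1, 0]] = 3.0;
        base_weights[[1, 1]] = 4.0;
        let base_intercepts = Array1::from_vec(vec![0.5, -0.5]);

        let meta = simd_generate_meta_features(
            &X.view(),
            &base_weights.view(),
            &base_intercepts.view(),
        )
        .unwrap();
        assert_eq!(meta.dim(), (2, 2));
        // Sample 0: [1*1 + 0.5, 1*3 - 0.5] = [1.5, 2.5]
        // Sample 1: [2 + 0.5, 4 - 0.5]    = [2.5, 3.5]
        assert!((meta[[0, 0]] - 1.5).abs() < 1e-10);
        assert!((meta[[0, 1]] - 2.5).abs() < 1e-10);
        assert!((meta[[1, 0]] - 2.5).abs() < 1e-10);
        assert!((meta[[1, 1]] - 3.5).abs() < 1e-10);

        // Meta-learner: 0.4 * est0 + 0.6 * est1 + 1.0
        let meta_weights = Array1::from_vec(vec![0.4, 0.6]);
        let final_preds =
            simd_aggregate_predictions(&meta.view(), &meta_weights.view(), 1.0).unwrap();
        assert!((final_preds[0] - 3.1).abs() < 1e-10);
        assert!((final_preds[1] - 4.1).abs() < 1e-10);
    }
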
    #[test]
    fn test_simd_mean() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let result = simd_mean(&data.view());
        assert!((result - 3.0).abs() < 1e-10);
    }

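    // Additional coverage (hedged sketches, not part of the original suite) for
    // the statistical helpers, the gradient computation, and the end-to-end
    // training/prediction path. The training test relies on the documented
    // behaviour that the base estimators stay at their zero initialization, so
    // with constant targets the meta-intercept converges to the target value.
    #[test]
    fn test_simd_variance() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mean = simd_mean(&data.view());
        // Bessel-corrected variance: (4 + 1 + 0 + 1 + 4) / 4 = 2.5
        assert!((simd_variance(&data.view(), mean) - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_simd_correlation_and_diversity() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        // b = 2a, so the correlation is exactly +1 and the pair diversity is 0.
        let corr = simd_correlation_coefficient(&a.view(), &b.view());
        assert!((corr - 1.0).abs() < 1e-10);

        let mut preds = Array2::<Float>::zeros((4, 2));
        for i in 0..4 {
            preds[[i, 0]] = a[i];
            preds[[i, 1]] = b[i];
        }
        let diversity = simd_compute_ensemble_diversity(&preds.view()).unwrap();
        assert!(diversity.abs() < 1e-10);
    }

    #[test]
    fn test_simd_compute_gradients() {
        // One feature, two samples, zero weights/intercept, no regularization:
        // errors are [-1, -2], so grad_w = (-1*1 + -2*2)/2 = -2.5 and
        // grad_b = (-1 + -2)/2 = -1.5.
        let mut X = Array2::<Float>::zeros((2, 1));
        X[[0, 0]] = 1.0;
        X[[1, 0]] = 2.0;
        let y = Array1::from_vec(vec![1.0, 2.0]);
        let w = Array1::from_vec(vec![0.0]);

        let (grad_w, grad_b) =
            simd_compute_gradients(&X.view(), &y.view(), &w.view(), 0.0, 0.0).unwrap();
        assert!((grad_w[0] + 2.5).abs() < 1e-10);
        assert!((grad_b + 1.5).abs() < 1e-10);

        // With a perfect fit (w = 1, b = 0) the only contribution left is the
        // L2 term: grad_w = l2_reg * w = 0.5.
        let w_fit = Array1::from_vec(vec![1.0]);
        let (grad_w, grad_b) =
            simd_compute_gradients(&X.view(), &y.view(), &w_fit.view(), 0.0, 0.5).unwrap();
        assert!((grad_w[0] - 0.5).abs() < 1e-10);
        assert!(grad_b.abs() < 1e-10);
    }

    #[test]
    fn test_train_and_predict_constant_target() {
        // Constant targets y = 2: only the meta-intercept is effectively
        // learned, and it should approach 2 after enough iterations.
        let mut X = Array2::<Float>::zeros((4, 2));
        for i in 0..4 {
            X[[i, 0]] = i as Float;
            X[[i, 1]] = (i * i) as Float;
        }
        let y = Array1::from_vec(vec![2.0, 2.0, 2.0, 2.0]);

        let model = simd_train_stacking_ensemble(&X.view(), &y.view(), 3, 0.1, 0.0, 200).unwrap();
        assert_eq!(model.n_features, 2);
        assert_eq!(model.n_estimators, 3);

        let preds = model.predict(&X.view()).unwrap();
        assert_eq!(preds.len(), 4);
        for p in preds.iter() {
            assert!((p - 2.0).abs() < 1e-6);
        }
    }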
}