sklears_ensemble/stacking/
blending.rs1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8 error::{Result, SklearsError},
9 prelude::Predict,
10 traits::{Fit, Trained, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
15#[derive(Debug)]
20pub struct BlendingClassifier<State = Untrained> {
21 pub(crate) holdout_ratio: f64,
22 pub(crate) random_state: Option<u64>,
23 pub(crate) state: PhantomData<State>,
24 pub(crate) n_base_estimators_: Option<usize>,
26 pub(crate) classes_: Option<Array1<i32>>,
27 pub(crate) n_features_in_: Option<usize>,
28}
29
30impl BlendingClassifier<Untrained> {
31 pub fn new(n_base_estimators: usize) -> Self {
33 Self {
34 holdout_ratio: 0.2, random_state: None,
36 state: PhantomData,
37 n_base_estimators_: Some(n_base_estimators),
38 classes_: None,
39 n_features_in_: None,
40 }
41 }
42
43 pub fn holdout_ratio(mut self, ratio: f64) -> Self {
45 if ratio <= 0.0 || ratio >= 1.0 {
46 panic!("Holdout ratio must be between 0 and 1");
47 }
48 self.holdout_ratio = ratio;
49 self
50 }
51
52 pub fn random_state(mut self, random_state: u64) -> Self {
54 self.random_state = Some(random_state);
55 self
56 }
57}
58
59impl Fit<Array2<Float>, Array1<i32>> for BlendingClassifier<Untrained> {
60 type Fitted = BlendingClassifier<Trained>;
61
62 fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<Self::Fitted> {
63 if x.nrows() != y.len() {
64 return Err(SklearsError::ShapeMismatch {
65 expected: format!("{} samples", x.nrows()),
66 actual: format!("{} samples", y.len()),
67 });
68 }
69
70 let n_features = x.ncols();
71
72 let mut classes: Vec<i32> = y.to_vec();
74 classes.sort_unstable();
75 classes.dedup();
76 let classes_array = Array1::from_vec(classes);
77
78 Ok(BlendingClassifier {
81 holdout_ratio: self.holdout_ratio,
82 random_state: self.random_state,
83 state: PhantomData,
84 n_base_estimators_: self.n_base_estimators_,
85 classes_: Some(classes_array),
86 n_features_in_: Some(n_features),
87 })
88 }
89}
90
91impl Predict<Array2<Float>, Array1<i32>> for BlendingClassifier<Trained> {
92 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
93 if x.ncols() != self.n_features_in_.unwrap() {
94 return Err(SklearsError::FeatureMismatch {
95 expected: self.n_features_in_.unwrap(),
96 actual: x.ncols(),
97 });
98 }
99
100 let n_samples = x.nrows();
102 let classes = self.classes_.as_ref().unwrap();
103
104 let predictions = Array1::from_elem(n_samples, classes[0]);
106
107 Ok(predictions)
108 }
109}
110
111impl BlendingClassifier<Trained> {
112 pub fn classes(&self) -> &Array1<i32> {
114 self.classes_.as_ref().unwrap()
115 }
116
117 pub fn n_features_in(&self) -> usize {
119 self.n_features_in_.unwrap()
120 }
121
122 pub fn n_base_estimators(&self) -> usize {
124 self.n_base_estimators_.unwrap()
125 }
126
127 pub fn holdout_ratio(&self) -> f64 {
129 self.holdout_ratio
130 }
131
132 pub fn random_state(&self) -> Option<u64> {
134 self.random_state
135 }
136}
137
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder calls must be reflected in the stored configuration.
    #[test]
    fn test_blending_creation() {
        let configured = BlendingClassifier::new(2)
            .holdout_ratio(0.3)
            .random_state(42);

        assert_eq!(configured.holdout_ratio, 0.3);
        assert_eq!(configured.random_state, Some(42));
        assert_eq!(configured.n_base_estimators_.unwrap(), 2);
    }

    /// Fitting records the data shape and classes; predicting yields one
    /// label per input row.
    #[test]
    fn test_blending_fit_predict() {
        let features = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0],
            [13.0, 14.0],
            [15.0, 16.0],
            [17.0, 18.0],
            [19.0, 20.0],
            [21.0, 22.0],
            [23.0, 24.0]
        ];
        let labels = array![0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1];

        let trained = BlendingClassifier::new(2)
            .fit(&features, &labels)
            .unwrap();

        assert_eq!(trained.n_features_in(), 2);
        assert_eq!(trained.classes().len(), 2);

        let predicted = trained.predict(&features).unwrap();
        assert_eq!(predicted.len(), 12);
    }

    /// A ratio outside the open interval (0, 1) must panic in the builder.
    #[test]
    #[should_panic(expected = "Holdout ratio must be between 0 and 1")]
    fn test_invalid_holdout_ratio() {
        let _unused = BlendingClassifier::new(2).holdout_ratio(1.5);
    }

    /// Two feature rows with a single label must be rejected by `fit`.
    #[test]
    fn test_shape_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![0];

        let outcome = BlendingClassifier::new(1).fit(&features, &labels);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Shape mismatch"));
    }

    /// Predicting with a different column count than seen in training must
    /// fail.
    #[test]
    fn test_feature_mismatch() {
        let train_features = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0]
        ];
        let train_labels = array![0, 1, 0, 1, 0, 1];
        let wide_features = array![[1.0, 2.0, 3.0]];

        let trained = BlendingClassifier::new(1)
            .fit(&train_features, &train_labels)
            .unwrap();
        let outcome = trained.predict(&wide_features);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Feature"));
    }
}