1use crate::utils::*;
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, Untrained},
12 types::Float,
13};
14
15#[derive(Debug, Clone)]
34pub struct ClassifierChain<S = Untrained> {
35 state: S,
36 order: Option<Vec<usize>>,
37 cv: Option<usize>,
38 random_state: Option<u64>,
39}
40
41impl ClassifierChain<Untrained> {
42 pub fn new() -> Self {
44 Self {
45 state: Untrained,
46 order: None,
47 cv: None,
48 random_state: None,
49 }
50 }
51
52 pub fn order(mut self, order: Vec<usize>) -> Self {
54 self.order = Some(order);
55 self
56 }
57
58 pub fn cv(mut self, cv: usize) -> Self {
60 self.cv = Some(cv);
61 self
62 }
63
64 pub fn random_state(mut self, random_state: u64) -> Self {
66 self.random_state = Some(random_state);
67 self
68 }
69}
70
71impl Default for ClassifierChain<Untrained> {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl Estimator for ClassifierChain<Untrained> {
78 type Config = ();
79 type Error = SklearsError;
80 type Float = Float;
81
82 fn config(&self) -> &Self::Config {
83 &()
84 }
85}
86
87impl ClassifierChain<Untrained> {
88 pub fn fit_simple(
90 self,
91 X: &ArrayView2<'_, Float>,
92 y: &Array2<i32>,
93 ) -> SklResult<ClassifierChain<ClassifierChainTrained>> {
94 let (n_samples, n_features) = X.dim();
95 let n_labels = y.ncols();
96
97 if n_samples != y.nrows() {
98 return Err(SklearsError::InvalidInput(
99 "X and y must have the same number of samples".to_string(),
100 ));
101 }
102
103 let order = self
105 .order
106 .clone()
107 .unwrap_or_else(|| (0..n_labels).collect());
108
109 if order.len() != n_labels {
110 return Err(SklearsError::InvalidInput(
111 "Chain order must contain all label indices".to_string(),
112 ));
113 }
114
115 let mut models = Vec::new();
117 let mut current_features = X.to_owned();
118
119 for (i, &label_idx) in order.iter().enumerate() {
120 let y_binary = y.column(label_idx).to_owned();
121
122 let model = train_binary_classifier(¤t_features.view(), &y_binary)?;
124 models.push(model);
125
126 if i < order.len() - 1 {
128 let predictions = predict_binary_classifier(¤t_features.view(), &models[i]);
129 let n_current_features = current_features.ncols();
130 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
131
132 new_features
134 .slice_mut(s![.., ..n_current_features])
135 .assign(¤t_features);
136
137 for j in 0..n_samples {
139 new_features[[j, n_current_features]] = predictions[j] as Float;
140 }
141
142 current_features = new_features;
143 }
144 }
145
146 let trained_state = ClassifierChainTrained {
147 models,
148 order,
149 n_features,
150 n_labels,
151 };
152
153 Ok(ClassifierChain {
154 state: trained_state,
155 order: self.order,
156 cv: self.cv,
157 random_state: self.random_state,
158 })
159 }
160}
161
162impl Fit<ArrayView2<'_, Float>, Array2<i32>, ClassifierChainTrained>
163 for ClassifierChain<Untrained>
164{
165 type Fitted = ClassifierChain<ClassifierChainTrained>;
166
167 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
168 self.fit_simple(X, y)
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct ClassifierChainTrained {
175 models: Vec<SimpleBinaryModel>,
176 order: Vec<usize>,
177 n_features: usize,
178 n_labels: usize,
179}
180
181impl Predict<ArrayView2<'_, Float>, Array2<i32>> for ClassifierChain<ClassifierChainTrained> {
182 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
183 let (n_samples, n_features) = X.dim();
184 if n_features != self.state.n_features {
185 return Err(SklearsError::InvalidInput(
186 "X has different number of features than training data".to_string(),
187 ));
188 }
189
190 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
191 let mut current_features = X.to_owned();
192
193 for (i, &label_idx) in self.state.order.iter().enumerate() {
195 let model = &self.state.models[i];
196 let label_predictions = predict_binary_classifier(¤t_features.view(), model);
197
198 for j in 0..n_samples {
200 predictions[[j, label_idx]] = label_predictions[j];
201 }
202
203 if i < self.state.order.len() - 1 {
205 let n_current_features = current_features.ncols();
206 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
207
208 new_features
210 .slice_mut(s![.., ..n_current_features])
211 .assign(¤t_features);
212
213 for j in 0..n_samples {
215 new_features[[j, n_current_features]] = label_predictions[j] as Float;
216 }
217
218 current_features = new_features;
219 }
220 }
221
222 Ok(predictions)
223 }
224}
225
226impl ClassifierChain<ClassifierChainTrained> {
227 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
229 let (n_samples, n_features) = X.dim();
230 if n_features != self.state.n_features {
231 return Err(SklearsError::InvalidInput(
232 "X has different number of features than training data".to_string(),
233 ));
234 }
235
236 let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
237 let mut current_features = X.to_owned();
238
239 for (i, &label_idx) in self.state.order.iter().enumerate() {
241 let model = &self.state.models[i];
242 let label_probabilities = predict_binary_probabilities(¤t_features.view(), model);
243
244 for j in 0..n_samples {
246 probabilities[[j, label_idx]] = label_probabilities[j];
247 }
248
249 if i < self.state.order.len() - 1 {
251 let label_predictions =
252 label_probabilities.mapv(|p| if p > 0.5 { 1.0 } else { 0.0 });
253 let n_current_features = current_features.ncols();
254 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
255
256 new_features
258 .slice_mut(s![.., ..n_current_features])
259 .assign(¤t_features);
260
261 for j in 0..n_samples {
263 new_features[[j, n_current_features]] = label_predictions[j];
264 }
265
266 current_features = new_features;
267 }
268 }
269
270 Ok(probabilities)
271 }
272
273 pub fn chain_order(&self) -> &[usize] {
275 &self.state.order
276 }
277
278 pub fn n_models(&self) -> usize {
280 self.state.models.len()
281 }
282
283 pub fn n_targets(&self) -> usize {
285 self.state.n_labels
286 }
287
288 pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
290 self.predict(X)
291 }
292
293 pub fn predict_monte_carlo(
295 &self,
296 X: &ArrayView2<'_, Float>,
297 n_samples: usize,
298 random_state: Option<u64>,
299 ) -> SklResult<Array2<Float>> {
300 if n_samples == 0 {
301 return Err(SklearsError::InvalidInput(
302 "n_samples must be greater than 0".to_string(),
303 ));
304 }
305 self.predict_proba(X)
307 }
308
309 pub fn predict_monte_carlo_labels(
311 &self,
312 X: &ArrayView2<'_, Float>,
313 n_samples: usize,
314 random_state: Option<u64>,
315 ) -> SklResult<Array2<i32>> {
316 if n_samples == 0 {
317 return Err(SklearsError::InvalidInput(
318 "n_samples must be greater than 0".to_string(),
319 ));
320 }
321 self.predict(X)
323 }
324}
325
326#[derive(Debug, Clone)]
345pub struct RegressorChain<S = Untrained> {
346 state: S,
347 order: Option<Vec<usize>>,
348 cv: Option<usize>,
349 random_state: Option<u64>,
350}
351
352impl RegressorChain<Untrained> {
353 pub fn new() -> Self {
355 Self {
356 state: Untrained,
357 order: None,
358 cv: None,
359 random_state: None,
360 }
361 }
362
363 pub fn order(mut self, order: Vec<usize>) -> Self {
365 self.order = Some(order);
366 self
367 }
368
369 pub fn cv(mut self, cv: usize) -> Self {
371 self.cv = Some(cv);
372 self
373 }
374
375 pub fn random_state(mut self, random_state: u64) -> Self {
377 self.random_state = Some(random_state);
378 self
379 }
380}
381
382impl Default for RegressorChain<Untrained> {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
388impl Estimator for RegressorChain<Untrained> {
389 type Config = ();
390 type Error = SklearsError;
391 type Float = Float;
392
393 fn config(&self) -> &Self::Config {
394 &()
395 }
396}
397
398impl RegressorChain<Untrained> {
399 pub fn fit_simple(
401 self,
402 X: &ArrayView2<'_, Float>,
403 y: &Array2<Float>,
404 ) -> SklResult<RegressorChain<RegressorChainTrained>> {
405 let (n_samples, n_features) = X.dim();
406 let n_targets = y.ncols();
407
408 if n_samples != y.nrows() {
409 return Err(SklearsError::InvalidInput(
410 "X and y must have the same number of samples".to_string(),
411 ));
412 }
413
414 let order = self
416 .order
417 .clone()
418 .unwrap_or_else(|| (0..n_targets).collect());
419
420 if order.len() != n_targets {
421 return Err(SklearsError::InvalidInput(
422 "Chain order must contain all target indices".to_string(),
423 ));
424 }
425
426 let mut models = Vec::new();
428 let mut current_features = X.to_owned();
429
430 for (i, &target_idx) in order.iter().enumerate() {
431 let y_target = y.column(target_idx).to_owned();
432
433 let model = train_simple_linear_classifier(¤t_features.view(), &y_target)?;
435 models.push(model);
436
437 if i < order.len() - 1 {
439 let predictions = predict_simple_linear(¤t_features.view(), &models[i]);
440 let n_current_features = current_features.ncols();
441 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
442
443 new_features
445 .slice_mut(s![.., ..n_current_features])
446 .assign(¤t_features);
447
448 for j in 0..n_samples {
450 new_features[[j, n_current_features]] = predictions[j];
451 }
452
453 current_features = new_features;
454 }
455 }
456
457 let trained_state = RegressorChainTrained {
458 models,
459 order,
460 n_features,
461 n_targets,
462 };
463
464 Ok(RegressorChain {
465 state: trained_state,
466 order: self.order,
467 cv: self.cv,
468 random_state: self.random_state,
469 })
470 }
471}
472
473impl Fit<ArrayView2<'_, Float>, Array2<Float>, RegressorChainTrained>
474 for RegressorChain<Untrained>
475{
476 type Fitted = RegressorChain<RegressorChainTrained>;
477
478 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
479 self.fit_simple(X, y)
480 }
481}
482
483#[derive(Debug, Clone)]
485pub struct RegressorChainTrained {
486 models: Vec<SimpleLinearClassifier>,
487 order: Vec<usize>,
488 n_features: usize,
489 n_targets: usize,
490}
491
492impl Predict<ArrayView2<'_, Float>, Array2<Float>> for RegressorChain<RegressorChainTrained> {
493 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
494 let (n_samples, n_features) = X.dim();
495 if n_features != self.state.n_features {
496 return Err(SklearsError::InvalidInput(
497 "X has different number of features than training data".to_string(),
498 ));
499 }
500
501 let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));
502 let mut current_features = X.to_owned();
503
504 for (i, &target_idx) in self.state.order.iter().enumerate() {
506 let model = &self.state.models[i];
507 let target_predictions = predict_simple_linear(¤t_features.view(), model);
508
509 for j in 0..n_samples {
511 predictions[[j, target_idx]] = target_predictions[j];
512 }
513
514 if i < self.state.order.len() - 1 {
516 let n_current_features = current_features.ncols();
517 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
518
519 new_features
521 .slice_mut(s![.., ..n_current_features])
522 .assign(¤t_features);
523
524 for j in 0..n_samples {
526 new_features[[j, n_current_features]] = target_predictions[j];
527 }
528
529 current_features = new_features;
530 }
531 }
532
533 Ok(predictions)
534 }
535}
536
537impl RegressorChain<RegressorChainTrained> {
538 pub fn chain_order(&self) -> &[usize] {
540 &self.state.order
541 }
542
543 pub fn n_models(&self) -> usize {
545 self.state.models.len()
546 }
547
548 pub fn get_model(&self, index: usize) -> Option<&SimpleLinearClassifier> {
550 self.state.models.get(index)
551 }
552
553 pub fn n_targets(&self) -> usize {
555 self.state.n_targets
556 }
557
558 pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
560 self.predict(X)
561 }
562}
563
564#[derive(Debug, Clone)]
582pub struct EnsembleOfChains<S = Untrained> {
583 state: S,
584 n_chains: usize,
585 chain_method: ChainMethod,
586 random_state: Option<u64>,
587}
588
589#[derive(Debug, Clone, Copy, PartialEq)]
591pub enum ChainMethod {
592 Random,
594 Fixed,
596 Bootstrap,
598}
599
600impl EnsembleOfChains<Untrained> {
601 pub fn new() -> Self {
603 Self {
604 state: Untrained,
605 n_chains: 10,
606 chain_method: ChainMethod::Random,
607 random_state: None,
608 }
609 }
610
611 pub fn n_chains(mut self, n_chains: usize) -> Self {
613 self.n_chains = n_chains;
614 self
615 }
616
617 pub fn chain_method(mut self, method: ChainMethod) -> Self {
619 self.chain_method = method;
620 self
621 }
622
623 pub fn random_state(mut self, random_state: u64) -> Self {
625 self.random_state = Some(random_state);
626 self
627 }
628}
629
630impl Default for EnsembleOfChains<Untrained> {
631 fn default() -> Self {
632 Self::new()
633 }
634}
635
636impl Estimator for EnsembleOfChains<Untrained> {
637 type Config = ();
638 type Error = SklearsError;
639 type Float = Float;
640
641 fn config(&self) -> &Self::Config {
642 &()
643 }
644}
645
646impl EnsembleOfChains<Untrained> {
647 pub fn fit_simple(
649 self,
650 X: &ArrayView2<'_, Float>,
651 y: &Array2<i32>,
652 ) -> SklResult<EnsembleOfChains<EnsembleOfChainsTrained>> {
653 let (n_samples, n_features) = X.dim();
654 let n_labels = y.ncols();
655
656 if n_samples != y.nrows() {
657 return Err(SklearsError::InvalidInput(
658 "X and y must have the same number of samples".to_string(),
659 ));
660 }
661
662 let mut chains = Vec::new();
663 let mut rng_state = self.random_state.unwrap_or(42);
664
665 for i in 0..self.n_chains {
666 let chain_order = match self.chain_method {
668 ChainMethod::Random => {
669 let mut order: Vec<usize> = (0..n_labels).collect();
670 for j in (1..order.len()).rev() {
672 rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
673 let k = (rng_state as usize) % (j + 1);
674 order.swap(j, k);
675 }
676 order
677 }
678 ChainMethod::Fixed => {
679 let mut order: Vec<usize> = (0..n_labels).collect();
681 order.rotate_left(i % n_labels);
682 order
683 }
684 ChainMethod::Bootstrap => {
685 let mut order: Vec<usize> = (0..n_labels).collect();
687 for j in (1..order.len()).rev() {
688 rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
689 let k = (rng_state as usize) % (j + 1);
690 order.swap(j, k);
691 }
692 order
693 }
694 };
695
696 let chain = ClassifierChain::new()
698 .order(chain_order)
699 .random_state(rng_state);
700
701 let trained_chain = chain.fit_simple(X, y)?;
702 chains.push(trained_chain);
703
704 rng_state = rng_state.wrapping_add(1);
705 }
706
707 let trained_state = EnsembleOfChainsTrained {
708 chains,
709 n_features,
710 n_labels,
711 };
712
713 Ok(EnsembleOfChains {
714 state: trained_state,
715 n_chains: self.n_chains,
716 chain_method: self.chain_method,
717 random_state: self.random_state,
718 })
719 }
720}
721
722impl Fit<ArrayView2<'_, Float>, Array2<i32>, EnsembleOfChainsTrained>
723 for EnsembleOfChains<Untrained>
724{
725 type Fitted = EnsembleOfChains<EnsembleOfChainsTrained>;
726
727 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
728 self.fit_simple(X, y)
729 }
730}
731
732#[derive(Debug, Clone)]
734pub struct EnsembleOfChainsTrained {
735 chains: Vec<ClassifierChain<ClassifierChainTrained>>,
736 n_features: usize,
737 n_labels: usize,
738}
739
740impl Predict<ArrayView2<'_, Float>, Array2<i32>> for EnsembleOfChains<EnsembleOfChainsTrained> {
741 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
742 let (n_samples, n_features) = X.dim();
743 if n_features != self.state.n_features {
744 return Err(SklearsError::InvalidInput(
745 "X has different number of features than training data".to_string(),
746 ));
747 }
748
749 let mut all_predictions = Vec::new();
751 for chain in &self.state.chains {
752 let predictions = chain.predict(X)?;
753 all_predictions.push(predictions);
754 }
755
756 let mut final_predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
758
759 for i in 0..n_samples {
760 for j in 0..self.state.n_labels {
761 let mut votes = 0;
762 for predictions in &all_predictions {
763 votes += predictions[[i, j]];
764 }
765 final_predictions[[i, j]] = if votes > (self.state.chains.len() as i32) / 2 {
767 1
768 } else {
769 0
770 };
771 }
772 }
773
774 Ok(final_predictions)
775 }
776}
777
778impl EnsembleOfChains<EnsembleOfChainsTrained> {
779 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
781 let (n_samples, n_features) = X.dim();
782 if n_features != self.state.n_features {
783 return Err(SklearsError::InvalidInput(
784 "X has different number of features than training data".to_string(),
785 ));
786 }
787
788 let mut all_probabilities = Vec::new();
790 for chain in &self.state.chains {
791 let probabilities = chain.predict_proba(X)?;
792 all_probabilities.push(probabilities);
793 }
794
795 let mut final_probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
797
798 for i in 0..n_samples {
799 for j in 0..self.state.n_labels {
800 let mut prob_sum = 0.0;
801 for probabilities in &all_probabilities {
802 prob_sum += probabilities[[i, j]];
803 }
804 final_probabilities[[i, j]] = prob_sum / self.state.chains.len() as Float;
805 }
806 }
807
808 Ok(final_probabilities)
809 }
810
811 pub fn n_chains(&self) -> usize {
813 self.state.chains.len()
814 }
815
816 pub fn get_chain(&self, index: usize) -> Option<&ClassifierChain<ClassifierChainTrained>> {
818 self.state.chains.get(index)
819 }
820
821 pub fn chain_diversity(&self) -> Float {
823 if self.state.chains.len() < 2 {
824 return 0.0;
825 }
826
827 let mut diversity_sum = 0.0;
828 let mut count = 0;
829
830 for i in 0..self.state.chains.len() {
832 for j in (i + 1)..self.state.chains.len() {
833 let order1 = self.state.chains[i].chain_order();
834 let order2 = self.state.chains[j].chain_order();
835
836 let mut agreements = 0;
838 for k in 0..order1.len() {
839 if order1[k] == order2[k] {
840 agreements += 1;
841 }
842 }
843
844 let similarity = agreements as Float / order1.len() as Float;
845 diversity_sum += 1.0 - similarity;
846 count += 1;
847 }
848 }
849
850 if count > 0 {
851 diversity_sum / count as Float
852 } else {
853 0.0
854 }
855 }
856
857 pub fn n_targets(&self) -> usize {
859 self.state.n_labels
860 }
861
862 pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
864 self.predict(X)
865 }
866
867 pub fn predict_proba_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
869 self.predict_proba(X)
870 }
871}
872
873#[derive(Debug, Clone)]
894pub struct BayesianClassifierChain<S = Untrained> {
895 state: S,
896 pub order: Option<Vec<usize>>,
898 pub n_samples: usize,
900 pub prior_strength: Float,
902 pub random_state: Option<u64>,
904}
905
906impl BayesianClassifierChain<Untrained> {
907 pub fn new() -> Self {
909 Self {
910 state: Untrained,
911 order: None,
912 n_samples: 100,
913 prior_strength: 1.0,
914 random_state: None,
915 }
916 }
917
918 pub fn order(mut self, order: Vec<usize>) -> Self {
920 self.order = Some(order);
921 self
922 }
923
924 pub fn n_samples(mut self, n_samples: usize) -> Self {
926 self.n_samples = n_samples;
927 self
928 }
929
930 pub fn prior_strength(mut self, prior_strength: Float) -> Self {
932 self.prior_strength = prior_strength;
933 self
934 }
935
936 pub fn random_state(mut self, random_state: u64) -> Self {
938 self.random_state = Some(random_state);
939 self
940 }
941}
942
943impl Default for BayesianClassifierChain<Untrained> {
944 fn default() -> Self {
945 Self::new()
946 }
947}
948
949impl Estimator for BayesianClassifierChain<Untrained> {
950 type Config = ();
951 type Error = SklearsError;
952 type Float = Float;
953
954 fn config(&self) -> &Self::Config {
955 &()
956 }
957}
958
959impl BayesianClassifierChain<Untrained> {
960 #[allow(non_snake_case)]
962 pub fn fit_simple(
963 self,
964 X: &ArrayView2<'_, Float>,
965 y: &Array2<i32>,
966 ) -> SklResult<BayesianClassifierChain<BayesianClassifierChainTrained>> {
967 let (n_samples, n_features) = X.dim();
968 let n_labels = y.ncols();
969
970 if n_samples != y.nrows() {
971 return Err(SklearsError::InvalidInput(
972 "X and y must have the same number of samples".to_string(),
973 ));
974 }
975
976 for &val in y.iter() {
978 if val != 0 && val != 1 {
979 return Err(SklearsError::InvalidInput(
980 "y must contain only binary values (0 or 1)".to_string(),
981 ));
982 }
983 }
984
985 let order = self
987 .order
988 .clone()
989 .unwrap_or_else(|| (0..n_labels).collect());
990
991 if order.len() != n_labels {
992 return Err(SklearsError::InvalidInput(
993 "Chain order must contain all label indices".to_string(),
994 ));
995 }
996
997 let feature_means = X.mean_axis(Axis(0)).unwrap();
999 let feature_stds = X.std_axis(Axis(0), 0.0);
1000 let X_standardized = standardize_features_simple(X, &feature_means, &feature_stds);
1001
1002 let mut bayesian_models = Vec::new();
1004 let mut current_features = X_standardized;
1005
1006 for (i, &label_idx) in order.iter().enumerate() {
1007 let y_binary = y.column(label_idx).to_owned();
1008
1009 let model = train_bayesian_binary_classifier(
1011 ¤t_features,
1012 &y_binary,
1013 self.prior_strength,
1014 )?;
1015 bayesian_models.push(model);
1016
1017 if i < order.len() - 1 {
1019 let predictions =
1020 predict_bayesian_mean(¤t_features.view(), &bayesian_models[i]);
1021 let n_current_features = current_features.ncols();
1022 let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));
1023
1024 new_features
1026 .slice_mut(s![.., ..n_current_features])
1027 .assign(¤t_features);
1028
1029 for j in 0..n_samples {
1031 new_features[[j, n_current_features]] = predictions[j];
1032 }
1033
1034 current_features = new_features;
1035 }
1036 }
1037
1038 let trained_state = BayesianClassifierChainTrained {
1039 bayesian_models,
1040 order,
1041 n_features,
1042 n_labels,
1043 feature_means,
1044 feature_stds,
1045 };
1046
1047 Ok(BayesianClassifierChain {
1048 state: trained_state,
1049 order: None,
1050 n_samples: self.n_samples,
1051 prior_strength: self.prior_strength,
1052 random_state: self.random_state,
1053 })
1054 }
1055}
1056
1057impl Fit<ArrayView2<'_, Float>, Array2<i32>, BayesianClassifierChainTrained>
1058 for BayesianClassifierChain<Untrained>
1059{
1060 type Fitted = BayesianClassifierChain<BayesianClassifierChainTrained>;
1061
1062 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
1063 self.fit_simple(X, y)
1064 }
1065}
1066
1067#[derive(Debug, Clone)]
1069pub struct BayesianClassifierChainTrained {
1070 bayesian_models: Vec<BayesianBinaryModel>,
1071 order: Vec<usize>,
1072 n_features: usize,
1073 n_labels: usize,
1074 feature_means: Array1<Float>,
1075 feature_stds: Array1<Float>,
1076}
1077
1078impl Predict<ArrayView2<'_, Float>, Array2<i32>>
1079 for BayesianClassifierChain<BayesianClassifierChainTrained>
1080{
1081 #[allow(non_snake_case)]
1082 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
1083 let (n_samples, n_features) = X.dim();
1084 if n_features != self.state.feature_means.len() {
1085 return Err(SklearsError::InvalidInput(
1086 "X has different number of features than training data".to_string(),
1087 ));
1088 }
1089
1090 let X_standardized =
1092 standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);
1093
1094 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
1095 let mut current_features = X_standardized;
1096
1097 for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
1099 let model = &self.state.bayesian_models[chain_pos];
1100
1101 let label_predictions = predict_bayesian_binary(¤t_features.view(), model);
1103
1104 for i in 0..n_samples {
1106 predictions[[i, label_idx]] = if label_predictions[i] > 0.5 { 1 } else { 0 };
1107 }
1108
1109 if chain_pos < self.state.order.len() - 1 {
1111 let mut new_features =
1112 Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));
1113
1114 new_features
1116 .slice_mut(s![.., ..current_features.ncols()])
1117 .assign(¤t_features);
1118
1119 for i in 0..n_samples {
1121 new_features[[i, current_features.ncols()]] =
1122 predictions[[i, label_idx]] as Float;
1123 }
1124
1125 current_features = new_features;
1126 }
1127 }
1128
1129 Ok(predictions)
1130 }
1131}
1132
1133impl BayesianClassifierChain<BayesianClassifierChainTrained> {
1134 #[allow(non_snake_case)]
1136 pub fn predict_uncertainty(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1137 let (n_samples, n_features) = X.dim();
1138 if n_features != self.state.feature_means.len() {
1139 return Err(SklearsError::InvalidInput(
1140 "X has different number of features than training data".to_string(),
1141 ));
1142 }
1143
1144 let X_standardized =
1146 standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);
1147
1148 let mut uncertainties = Array2::<Float>::zeros((n_samples, self.state.n_labels));
1149 let mut current_features = X_standardized;
1150
1151 for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
1153 let model = &self.state.bayesian_models[chain_pos];
1154
1155 let (means, variances) = predict_bayesian_uncertainty(¤t_features.view(), model)?;
1157
1158 for i in 0..n_samples {
1160 uncertainties[[i, label_idx]] = variances[i];
1161 }
1162
1163 if chain_pos < self.state.order.len() - 1 {
1165 let mut new_features =
1166 Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));
1167
1168 new_features
1170 .slice_mut(s![.., ..current_features.ncols()])
1171 .assign(¤t_features);
1172
1173 for i in 0..n_samples {
1175 new_features[[i, current_features.ncols()]] = means[i];
1176 }
1177
1178 current_features = new_features;
1179 }
1180 }
1181
1182 Ok(uncertainties)
1183 }
1184
1185 #[allow(non_snake_case)]
1187 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1188 let (n_samples, n_features) = X.dim();
1189 if n_features != self.state.feature_means.len() {
1190 return Err(SklearsError::InvalidInput(
1191 "X has different number of features than training data".to_string(),
1192 ));
1193 }
1194
1195 let X_standardized =
1197 standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);
1198
1199 let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
1200 let mut current_features = X_standardized;
1201
1202 for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
1204 let model = &self.state.bayesian_models[chain_pos];
1205
1206 let label_probabilities = predict_bayesian_binary(¤t_features.view(), model);
1208
1209 for i in 0..n_samples {
1211 probabilities[[i, label_idx]] = label_probabilities[i];
1212 }
1213
1214 if chain_pos < self.state.order.len() - 1 {
1216 let mut new_features =
1217 Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));
1218
1219 new_features
1221 .slice_mut(s![.., ..current_features.ncols()])
1222 .assign(¤t_features);
1223
1224 for i in 0..n_samples {
1226 new_features[[i, current_features.ncols()]] = label_probabilities[i];
1227 }
1228
1229 current_features = new_features;
1230 }
1231 }
1232
1233 Ok(probabilities)
1234 }
1235
1236 pub fn chain_order(&self) -> &[usize] {
1238 &self.state.order
1239 }
1240
1241 pub fn n_models(&self) -> usize {
1243 self.state.bayesian_models.len()
1244 }
1245
1246 pub fn model_posterior_stats(
1248 &self,
1249 model_idx: usize,
1250 ) -> Option<(&Array1<Float>, &Array2<Float>)> {
1251 self.state
1252 .bayesian_models
1253 .get(model_idx)
1254 .map(|model| (&model.weight_mean, &model.weight_cov))
1255 }
1256
1257 pub fn order(&self) -> &[usize] {
1259 &self.state.order
1260 }
1261}
1262
1263fn predict_binary_classifier(X: &ArrayView2<Float>, model: &SimpleBinaryModel) -> Array1<i32> {
1267 let raw_scores = X.dot(&model.weights) + model.bias;
1268 raw_scores.mapv(|x| if x > 0.0 { 1 } else { 0 })
1269}