1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::Transform,
7 types::Float,
8};
9
10use crate::{PipelinePredictor, PipelineStep};
11
/// A trivial transformer used to exercise pipeline plumbing in tests.
///
/// It multiplies every input element by a fixed scale factor and learns nothing
/// during fitting.
#[derive(Debug, Clone)]
pub struct MockTransformer {
    /// Multiplier applied element-wise by `transform`.
    scale_factor: f64,
}
17
18impl MockTransformer {
19 #[must_use]
21 pub fn new() -> Self {
22 Self { scale_factor: 1.0 }
23 }
24
25 #[must_use]
27 pub fn with_scale(scale_factor: f64) -> Self {
28 Self { scale_factor }
29 }
30}
31
32impl Default for MockTransformer {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl Transform<ArrayView2<'_, Float>, Array2<f64>> for MockTransformer {
39 fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
40 let transformed = x.mapv(|val| val * self.scale_factor);
41 Ok(transformed)
42 }
43}
44
45impl PipelineStep for MockTransformer {
46 fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
47 Transform::transform(self, x)
48 }
49
50 fn fit(
51 &mut self,
52 _x: &ArrayView2<'_, Float>,
53 _y: Option<&ArrayView1<'_, Float>>,
54 ) -> SklResult<()> {
55 Ok(())
57 }
58
59 fn clone_step(&self) -> Box<dyn PipelineStep> {
60 Box::new(self.clone())
61 }
62}
63
64impl MockTransformer {
66 pub fn fit(
68 &mut self,
69 _x: &ArrayView2<'_, Float>,
70 _y: Option<&ArrayView1<'_, Float>>,
71 ) -> SklResult<()> {
72 Ok(())
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct MockPredictor {
79 coefficients: Option<Array1<f64>>,
80 intercept: f64,
81 fitted: bool,
82}
83
84impl MockPredictor {
85 #[must_use]
87 pub fn new() -> Self {
88 Self {
89 coefficients: None,
90 intercept: 0.0,
91 fitted: false,
92 }
93 }
94
95 #[must_use]
97 pub fn with_coefficients(coefficients: Array1<f64>, intercept: f64) -> Self {
98 Self {
99 coefficients: Some(coefficients),
100 intercept,
101 fitted: true,
102 }
103 }
104
105 #[must_use]
107 pub fn is_fitted(&self) -> bool {
108 self.fitted
109 }
110}
111
112impl Default for MockPredictor {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl PipelinePredictor for MockPredictor {
119 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
120 if !self.fitted {
121 return Err(SklearsError::NotFitted {
122 operation: "predict".to_string(),
123 });
124 }
125
126 let predictions = if let Some(ref coef) = self.coefficients {
127 let mut predictions = Array1::zeros(x.nrows());
129 for i in 0..x.nrows() {
130 let mut sum = self.intercept;
131 for j in 0..x.ncols().min(coef.len()) {
132 sum += x[[i, j]] * coef[j];
133 }
134 predictions[i] = sum;
135 }
136 predictions
137 } else {
138 let mut predictions = Array1::zeros(x.nrows());
140 for i in 0..x.nrows() {
141 predictions[i] = x.row(i).mapv(|v| v).mean().unwrap_or(0.0) + self.intercept;
142 }
143 predictions
144 };
145
146 Ok(predictions)
147 }
148
149 fn fit(&mut self, x: &ArrayView2<'_, Float>, y: &ArrayView1<'_, Float>) -> SklResult<()> {
150 let n_features = x.ncols();
152 let mut coefficients = Array1::zeros(n_features);
153
154 for j in 0..n_features {
155 coefficients[j] = x.column(j).mapv(|v| v).mean().unwrap_or(0.0) / n_features as f64;
156 }
157
158 self.coefficients = Some(coefficients);
159 self.intercept = y.mapv(|v| v).mean().unwrap_or(0.0);
160 self.fitted = true;
161
162 Ok(())
163 }
164
165 fn clone_predictor(&self) -> Box<dyn PipelinePredictor> {
166 Box::new(self.clone())
167 }
168}