1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Transform},
10 types::Float,
11};
12
13use crate::{Binarizer, LabelBinarizer, NormType, Normalizer};
14
15pub fn scale(
28 x: &Array2<Float>,
29 axis: usize,
30 with_mean: bool,
31 with_std: bool,
32) -> Result<Array2<Float>> {
33 let mut result = x.clone();
35
36 if axis == 0 {
37 for j in 0..x.ncols() {
39 let column = x.column(j);
40 let mean = if with_mean {
41 column.mean().unwrap_or(0.0)
42 } else {
43 0.0
44 };
45 let std = if with_std {
46 column.std(0.0).max(1e-8)
47 } else {
48 1.0
49 };
50
51 for i in 0..x.nrows() {
52 result[[i, j]] = (x[[i, j]] - mean) / std;
53 }
54 }
55 } else if axis == 1 {
56 for i in 0..x.nrows() {
58 let row = x.row(i);
59 let mean = if with_mean {
60 row.mean().unwrap_or(0.0)
61 } else {
62 0.0
63 };
64 let std = if with_std {
65 row.std(0.0).max(1e-8)
66 } else {
67 1.0
68 };
69
70 for j in 0..x.ncols() {
71 result[[i, j]] = (x[[i, j]] - mean) / std;
72 }
73 }
74 } else {
75 return Err(SklearsError::InvalidInput(format!(
76 "axis must be 0 or 1, got {axis}"
77 )));
78 }
79
80 Ok(result)
81}
82
83pub fn normalize(x: &Array2<Float>, norm: NormType, axis: usize) -> Result<Array2<Float>> {
97 if axis == 1 {
98 let normalizer = Normalizer::new().norm(norm);
100 normalizer.transform(x)
101 } else if axis == 0 {
102 let x_t = x.t().to_owned();
104 let normalizer = Normalizer::new().norm(norm);
105 let normalized = normalizer.transform(&x_t)?;
106 Ok(normalized.t().to_owned())
107 } else {
108 Err(SklearsError::InvalidInput(format!(
109 "axis must be 0 or 1, got {axis}"
110 )))
111 }
112}
113
114pub fn binarize(x: &Array2<Float>, threshold: Float) -> Result<Array2<Float>> {
123 let binarizer = Binarizer::new().threshold(threshold);
124 let fitted = binarizer.fit(x, &())?;
125 fitted.transform(x)
126}
127
128pub fn maxabs_scale(x: &Array2<Float>, axis: usize) -> Result<Array2<Float>> {
137 if axis != 0 {
138 return Err(SklearsError::InvalidInput(
139 "maxabs_scale only supports axis=0".to_string(),
140 ));
141 }
142
143 let mut result = x.clone();
144
145 for j in 0..x.ncols() {
146 let column = x.column(j);
147 let max_abs = column.iter().map(|&v| v.abs()).fold(0.0, Float::max);
148
149 if max_abs > 1e-8 {
150 for i in 0..x.nrows() {
151 result[[i, j]] = x[[i, j]] / max_abs;
152 }
153 }
154 }
155
156 Ok(result)
157}
158
159pub fn minmax_scale(
169 x: &Array2<Float>,
170 feature_range: (Float, Float),
171 axis: usize,
172) -> Result<Array2<Float>> {
173 if axis != 0 {
174 return Err(SklearsError::InvalidInput(
175 "minmax_scale only supports axis=0".to_string(),
176 ));
177 }
178
179 let mut result = x.clone();
180 let (min_range, max_range) = feature_range;
181
182 for j in 0..x.ncols() {
183 let column = x.column(j);
184 let min_val = column.iter().fold(Float::INFINITY, |a, &b| a.min(b));
185 let max_val = column.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
186 let range = max_val - min_val;
187
188 if range > 1e-8 {
189 for i in 0..x.nrows() {
190 let normalized = (x[[i, j]] - min_val) / range;
191 result[[i, j]] = normalized * (max_range - min_range) + min_range;
192 }
193 } else {
194 let midpoint = (min_range + max_range) / 2.0;
196 for i in 0..x.nrows() {
197 result[[i, j]] = midpoint;
198 }
199 }
200 }
201
202 Ok(result)
203}
204
205pub fn robust_scale(
217 x: &Array2<Float>,
218 axis: usize,
219 with_centering: bool,
220 with_scaling: bool,
221 quantile_range: (Float, Float),
222) -> Result<Array2<Float>> {
223 if axis != 0 {
224 return Err(SklearsError::InvalidInput(
225 "robust_scale only supports axis=0".to_string(),
226 ));
227 }
228
229 let mut result = x.clone();
230
231 for j in 0..x.ncols() {
232 let mut column: Vec<Float> = x.column(j).to_vec();
233 column.sort_by(|a, b| a.partial_cmp(b).unwrap());
234
235 let n = column.len();
236 let q1_idx = ((n as Float) * quantile_range.0) as usize;
237 let q3_idx = ((n as Float) * quantile_range.1) as usize;
238
239 let q1 = column[q1_idx.min(n - 1)];
240 let q3 = column[q3_idx.min(n - 1)];
241 let median = column[n / 2];
242
243 let center = if with_centering { median } else { 0.0 };
244 let scale = if with_scaling && (q3 - q1) > 1e-8 {
245 q3 - q1
246 } else {
247 1.0
248 };
249
250 for i in 0..x.nrows() {
251 result[[i, j]] = (x[[i, j]] - center) / scale;
252 }
253 }
254
255 Ok(result)
256}
257
258pub fn add_dummy_feature(x: &Array2<Float>, value: Float) -> Result<Array2<Float>> {
317 let n_samples = x.nrows();
318 let n_features = x.ncols();
319
320 let mut x_with_dummy = Array2::zeros((n_samples, n_features + 1));
322
323 x_with_dummy.column_mut(0).fill(value);
325
326 x_with_dummy
328 .slice_mut(scirs2_core::ndarray::s![.., 1..])
329 .assign(x);
330
331 Ok(x_with_dummy)
332}
333
334pub fn label_binarize<T>(y: &Array1<T>, neg_label: i32, pos_label: i32) -> Result<Array2<Float>>
344where
345 T: std::hash::Hash + Eq + Clone + std::fmt::Debug + Ord + Send + Sync,
346{
347 let binarizer = LabelBinarizer::<T>::new()
348 .neg_label(neg_label)
349 .pos_label(pos_label);
350 let fitted = binarizer.fit(y, &())?;
351 fitted.transform(y)
352}
353
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{arr1, arr2};

    /// Column-wise standardization yields zero mean and unit deviation.
    #[test]
    fn test_scale() {
        let x = arr2(&[[0.0, 0.0], [0.0, 0.0], [1.0, 1.0], [1.0, 1.0]]);

        let scaled = scale(&x, 0, true, true).unwrap();

        for j in 0..x.ncols() {
            let col_mean = scaled.column(j).mean().unwrap();
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
        }

        for j in 0..x.ncols() {
            let col = scaled.column(j);
            let std = col.std(0.0);
            assert_abs_diff_eq!(std, 1.0, epsilon = 1e-10);
        }
    }

    /// An axis other than 0 or 1 is rejected.
    #[test]
    fn test_scale_invalid_axis() {
        let x = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        assert!(scale(&x, 2, true, true).is_err());
    }

    /// Row-wise L2 normalization produces unit-length rows.
    #[test]
    fn test_normalize() {
        let x = arr2(&[[4.0, 3.0], [1.0, 2.0]]);

        let normalized = normalize(&x, NormType::L2, 1).unwrap();

        for i in 0..x.nrows() {
            let row = normalized.row(i);
            let norm = row.dot(&row).sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    /// An axis other than 0 or 1 is rejected.
    #[test]
    fn test_normalize_invalid_axis() {
        let x = arr2(&[[4.0, 3.0], [1.0, 2.0]]);

        assert!(normalize(&x, NormType::L2, 2).is_err());
    }

    #[test]
    fn test_binarize() {
        let x = arr2(&[[0.5, 1.5], [2.5, 3.5]]);

        let binarized = binarize(&x, 2.0).unwrap();

        assert_eq!(binarized[[0, 0]], 0.0);
        assert_eq!(binarized[[0, 1]], 0.0);
        assert_eq!(binarized[[1, 0]], 1.0);
        assert_eq!(binarized[[1, 1]], 1.0);
    }

    /// Every column is divided by its largest absolute value.
    #[test]
    fn test_maxabs_scale() {
        let x = arr2(&[[1.0, -2.0], [-4.0, 1.0]]);

        let scaled = maxabs_scale(&x, 0).unwrap();

        assert_abs_diff_eq!(scaled[[0, 0]], 0.25, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 0]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[0, 1]], -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 1]], 0.5, epsilon = 1e-10);
    }

    /// Only column-wise scaling is supported.
    #[test]
    fn test_maxabs_scale_invalid_axis() {
        let x = arr2(&[[1.0, -2.0], [-4.0, 1.0]]);

        assert!(maxabs_scale(&x, 1).is_err());
    }

    #[test]
    fn test_minmax_scale() {
        let x = arr2(&[[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]]);

        let scaled = minmax_scale(&x, (0.0, 1.0), 0).unwrap();

        for j in 0..x.ncols() {
            let col = scaled.column(j);
            let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            assert_abs_diff_eq!(min, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(max, 1.0, epsilon = 1e-10);
        }
    }

    /// A constant column is mapped to the midpoint of the target range.
    #[test]
    fn test_minmax_scale_constant_column() {
        let x = arr2(&[[3.0], [3.0]]);

        let scaled = minmax_scale(&x, (0.0, 1.0), 0).unwrap();

        assert_abs_diff_eq!(scaled[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 0]], 0.5, epsilon = 1e-10);
    }

    /// Only column-wise scaling is supported.
    #[test]
    fn test_minmax_scale_invalid_axis() {
        let x = arr2(&[[0.0, 0.0], [1.0, 2.0]]);

        assert!(minmax_scale(&x, (0.0, 1.0), 1).is_err());
    }

    /// Five sorted values: quantiles land on 2 and 4, the median on 3.
    #[test]
    fn test_robust_scale() {
        let x = arr2(&[[1.0], [2.0], [3.0], [4.0], [5.0]]);

        let scaled = robust_scale(&x, 0, true, true, (0.25, 0.75)).unwrap();

        let expected = [-1.0, -0.5, 0.0, 0.5, 1.0];
        for (i, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(scaled[[i, 0]], *want, epsilon = 1e-10);
        }
    }

    /// Only column-wise scaling is supported.
    #[test]
    fn test_robust_scale_invalid_axis() {
        let x = arr2(&[[1.0], [2.0]]);

        assert!(robust_scale(&x, 1, true, true, (0.25, 0.75)).is_err());
    }

    #[test]
    fn test_add_dummy_feature() {
        let x = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let x_with_dummy = add_dummy_feature(&x, 1.0).unwrap();

        assert_eq!(x_with_dummy.shape(), &[2, 3]);
        assert_eq!(x_with_dummy[[0, 0]], 1.0);
        assert_eq!(x_with_dummy[[1, 0]], 1.0);
        assert_eq!(x_with_dummy[[0, 1]], 1.0);
        assert_eq!(x_with_dummy[[0, 2]], 2.0);
    }

    #[test]
    fn test_label_binarize() {
        let y = arr1(&[0, 1, 2, 1, 0]);

        let binarized = label_binarize(&y, 0, 1).unwrap();

        assert_eq!(binarized.shape(), &[5, 3]);

        assert_eq!(binarized[[0, 0]], 1.0);
        assert_eq!(binarized[[1, 1]], 1.0);
        assert_eq!(binarized[[2, 2]], 1.0);
    }
}