use super::common::*;
use scirs2_core::ndarray::{Array1, Axis};
/// Statistics recorded by `PyStandardScaler::fit` and reused by the transform
/// methods and the read-only getters.
#[derive(Debug, Clone)]
struct StandardScalerState {
/// Per-column offset subtracted by `transform` when centering is enabled;
/// filled with zeros when the scaler was fitted with `with_mean = false`.
mean: Array1<f64>,
/// Per-column divisor used by `transform`: the square root of `var`, with
/// values below `1e-10` replaced by `1.0`; all ones when `with_std = false`.
scale: Array1<f64>,
/// Per-column variance estimate computed by `fit` (sum of squared deviations
/// divided by the number of rows); all ones when `with_std = false`.
var: Array1<f64>,
/// Number of columns in the array passed to `fit`.
n_features: usize,
/// Number of rows in the array passed to `fit`.
n_samples_seen: usize,
}
/// Python-facing standard scaler, exposed to Python under the name `StandardScaler`.
///
/// Holds the constructor flags plus the statistics learned by `fit`.
#[pyclass(name = "StandardScaler")]
pub struct PyStandardScaler {
/// Constructor flag. In the code visible in this file it is only echoed by
/// `__repr__` and does not influence `fit` or the transform methods.
copy: bool,
/// When true, `fit` stores the column means and `transform` subtracts them.
with_mean: bool,
/// When true, `fit` stores a per-column scale and `transform` divides by it.
with_std: bool,
/// `None` until `fit` succeeds; afterwards the learned statistics.
state: Option<StandardScalerState>,
}
#[pymethods]
impl PyStandardScaler {
#[new]
#[pyo3(signature = (copy=true, with_mean=true, with_std=true))]
fn new(copy: bool, with_mean: bool, with_std: bool) -> Self {
Self {
copy,
with_mean,
with_std,
state: None,
}
}
fn fit(&mut self, x: PyReadonlyArray2<f64>) -> PyResult<()> {
let x_array = pyarray_to_core_array2(&x)?;
validate_fit_array(&x_array)?;
let n_samples = x_array.nrows();
let n_features = x_array.ncols();
let mean = if self.with_mean {
x_array.mean_axis(Axis(0)).expect("array should have elements for mean computation")
} else {
Array1::zeros(n_features)
};
let (var, scale) = if self.with_std {
let mut var = Array1::zeros(n_features);
for j in 0..n_features {
let col = x_array.column(j);
let mean_j = mean[j];
let sum_sq_diff: f64 = col.iter().map(|&x| (x - mean_j).powi(2)).sum();
var[j] = sum_sq_diff / n_samples as f64;
}
let scale = var.mapv(|v| {
let std = v.sqrt();
if std < 1e-10 {
1.0 } else {
std
}
});
(var, scale)
} else {
(Array1::ones(n_features), Array1::ones(n_features))
};
self.state = Some(StandardScalerState {
mean,
scale,
var,
n_features,
n_samples_seen: n_samples,
});
Ok(())
}
fn transform<'py>(&self, py: Python<'py>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray2<f64>>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
let x_array = pyarray_to_core_array2(&x)?;
validate_transform_array(&x_array, state.n_features)?;
let mut transformed = x_array.clone();
if self.with_mean {
for j in 0..state.n_features {
for i in 0..transformed.nrows() {
transformed[[i, j]] -= state.mean[j];
}
}
}
if self.with_std {
for j in 0..state.n_features {
for i in 0..transformed.nrows() {
transformed[[i, j]] /= state.scale[j];
}
}
}
core_array2_to_py(py, &transformed)
}
fn fit_transform<'py>(
&mut self,
py: Python<'py>,
x: PyReadonlyArray2<f64>,
) -> PyResult<Py<PyArray2<f64>>> {
let x_array = pyarray_to_core_array2(&x)?;
self.fit(x)?;
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
let mut transformed = x_array.clone();
if self.with_mean {
for j in 0..state.n_features {
for i in 0..transformed.nrows() {
transformed[[i, j]] -= state.mean[j];
}
}
}
if self.with_std {
for j in 0..state.n_features {
for i in 0..transformed.nrows() {
transformed[[i, j]] /= state.scale[j];
}
}
}
core_array2_to_py(py, &transformed)
}
fn inverse_transform<'py>(
&self,
py: Python<'py>,
x: PyReadonlyArray2<f64>,
) -> PyResult<Py<PyArray2<f64>>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
let x_array = pyarray_to_core_array2(&x)?;
validate_transform_array(&x_array, state.n_features)?;
let mut inverse = x_array.clone();
if self.with_std {
for j in 0..state.n_features {
for i in 0..inverse.nrows() {
inverse[[i, j]] *= state.scale[j];
}
}
}
if self.with_mean {
for j in 0..state.n_features {
for i in 0..inverse.nrows() {
inverse[[i, j]] += state.mean[j];
}
}
}
core_array2_to_py(py, &inverse)
}
#[getter]
fn mean_<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<f64>>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
Ok(core_array1_to_py(py, &state.mean))
}
#[getter]
fn scale_<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<f64>>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
Ok(core_array1_to_py(py, &state.scale))
}
#[getter]
fn var_<'py>(&self, py: Python<'py>) -> PyResult<Py<PyArray1<f64>>> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
Ok(core_array1_to_py(py, &state.var))
}
#[getter]
fn n_features_in_(&self) -> PyResult<usize> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
Ok(state.n_features)
}
#[getter]
fn n_samples_seen_(&self) -> PyResult<usize> {
let state = self
.state
.as_ref()
.ok_or_else(|| PyValueError::new_err("Scaler not fitted. Call fit() first."))?;
Ok(state.n_samples_seen)
}
fn __repr__(&self) -> String {
format!(
"StandardScaler(copy={}, with_mean={}, with_std={})",
self.copy, self.with_mean, self.with_std
)
}
}