use crate::error::{AprenderError, Result};
use crate::primitives::Matrix;
use crate::traits::Transformer;
use serde::{Deserialize, Serialize};
use std::path::Path;
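/// Standardizes features by removing the per-feature mean and scaling to unit
/// variance, i.e. `z = (x - mean) / std` computed column-wise over the rows of
/// the input matrix.
///
/// Centering and scaling can be disabled individually through the `with_mean`
/// and `with_std` builder methods.
///
/// # Examples
///
/// A minimal usage sketch. The `aprender::preprocessing` import path is an
/// assumption about the crate layout, so the example is not compiled.
///
/// ```ignore
/// use aprender::preprocessing::StandardScaler;
/// use aprender::primitives::Matrix;
/// use aprender::traits::Transformer;
///
/// let x = Matrix::from_vec(3, 2, vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0]).unwrap();
/// let mut scaler = StandardScaler::new();
/// scaler.fit(&x).unwrap();
/// let z = scaler.transform(&x).unwrap();
/// // Each column of `z` now has zero mean and (population) unit variance.
/// let x_back = scaler.inverse_transform(&z).unwrap();
/// ```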
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardScaler {
mean: Option<Vec<f32>>,
std: Option<Vec<f32>>,
with_mean: bool,
with_std: bool,
}
impl Default for StandardScaler {
fn default() -> Self {
Self::new()
}
}
impl StandardScaler {
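/// Creates an unfitted scaler with both centering and scaling enabled.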
#[must_use]
pub fn new() -> Self {
Self {
mean: None,
std: None,
with_mean: true,
with_std: true,
}
}
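/// Builder method: enable or disable mean centering (enabled by default).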
#[must_use]
pub fn with_mean(mut self, with_mean: bool) -> Self {
self.with_mean = with_mean;
self
}
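/// Builder method: enable or disable scaling to unit variance (enabled by default).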
#[must_use]
pub fn with_std(mut self, with_std: bool) -> Self {
self.with_std = with_std;
self
}
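/// Returns the per-feature means computed during fitting.
///
/// # Panics
///
/// Panics if the scaler has not been fitted.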
#[must_use]
pub fn mean(&self) -> &[f32] {
self.mean
.as_ref()
.expect("Scaler not fitted. Call fit() first.")
}
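/// Returns the per-feature standard deviations computed during fitting.
///
/// # Panics
///
/// Panics if the scaler has not been fitted.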
#[must_use]
pub fn std(&self) -> &[f32] {
self.std
.as_ref()
.expect("Scaler not fitted. Call fit() first.")
}
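/// Returns `true` if `fit` has been called and statistics are available.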
#[must_use]
pub fn is_fitted(&self) -> bool {
self.mean.is_some()
}
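/// Maps standardized values back to the original feature scale by reversing
/// `transform`: multiply by the stored std, then add the stored mean.
///
/// # Errors
///
/// Returns an error if the scaler has not been fitted or if the number of
/// columns in `x` does not match the fitted feature count.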
pub fn inverse_transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
let mean = self
.mean
.as_ref()
.ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
let std = self
.std
.as_ref()
.ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
let (n_samples, n_features) = x.shape();
if n_features != mean.len() {
return Err("Feature dimension mismatch".into());
}
let mut result = vec![0.0; n_samples * n_features];
for i in 0..n_samples {
for j in 0..n_features {
let mut val = x.get(i, j);
if self.with_std && std[j] > 1e-10 {
val *= std[j];
}
if self.with_mean {
val += mean[j];
}
result[i * n_features + j] = val;
}
}
Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
}
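/// Saves the fitted statistics (`mean`, `std`, `with_mean`, `with_std`) as
/// tensors in SafeTensors format.
///
/// # Errors
///
/// Returns an error if the scaler has not been fitted or if the file cannot
/// be written.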
pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
use crate::serialization::safetensors;
use std::collections::BTreeMap;
let mean = self
.mean
.as_ref()
.ok_or_else(|| "Cannot save unfitted scaler. Call fit() first.".to_string())?;
let std = self
.std
.as_ref()
.ok_or_else(|| "Cannot save unfitted scaler. Call fit() first.".to_string())?;
let mut tensors = BTreeMap::new();
tensors.insert("mean".to_string(), (mean.clone(), vec![mean.len()]));
tensors.insert("std".to_string(), (std.clone(), vec![std.len()]));
let with_mean_val = if self.with_mean { 1.0 } else { 0.0 };
tensors.insert("with_mean".to_string(), (vec![with_mean_val], vec![1]));
let with_std_val = if self.with_std { 1.0 } else { 0.0 };
tensors.insert("with_std".to_string(), (vec![with_std_val], vec![1]));
safetensors::save_safetensors(path, &tensors)?;
Ok(())
}
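/// Loads a scaler previously written with `save_safetensors`.
///
/// # Errors
///
/// Returns an error if a required tensor is missing or empty, or if the
/// `mean` and `std` vectors have different lengths.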
pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
use crate::serialization::safetensors;
let (metadata, raw_data) = safetensors::load_safetensors(path)?;
let mean_meta = metadata
.get("mean")
.ok_or_else(|| "Missing 'mean' tensor in SafeTensors file".to_string())?;
let mean = safetensors::extract_tensor(&raw_data, mean_meta)?;
let std_meta = metadata
.get("std")
.ok_or_else(|| "Missing 'std' tensor in SafeTensors file".to_string())?;
let std = safetensors::extract_tensor(&raw_data, std_meta)?;
if mean.len() != std.len() {
return Err("Mean and std vectors have different lengths".to_string());
}
let with_mean_meta = metadata
.get("with_mean")
.ok_or_else(|| "Missing 'with_mean' tensor".to_string())?;
let with_mean_data = safetensors::extract_tensor(&raw_data, with_mean_meta)?;
let with_mean = with_mean_data.first().copied().ok_or("Empty 'with_mean' tensor in SafeTensors file")? > 0.5;
let with_std_meta = metadata
.get("with_std")
.ok_or_else(|| "Missing 'with_std' tensor".to_string())?;
let with_std_data = safetensors::extract_tensor(&raw_data, with_std_meta)?;
let with_std = with_std_data.first().copied().ok_or("Empty 'with_std' tensor in SafeTensors file")? > 0.5;
Ok(Self {
mean: Some(mean),
std: Some(std),
with_mean,
with_std,
})
}
}
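/// `fit` computes the per-feature mean and the population (biased) standard
/// deviation; `transform` applies `(x - mean) / std`, skipping the division
/// for features whose std is at or below `1e-10` so constant columns are only
/// centered, never rescaled.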
impl Transformer for StandardScaler {
fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
let (n_samples, n_features) = x.shape();
if n_samples == 0 {
return Err("Cannot fit with zero samples".into());
}
let mut mean = vec![0.0; n_features];
for (j, mean_j) in mean.iter_mut().enumerate() {
let mut sum = 0.0;
for i in 0..n_samples {
sum += x.get(i, j);
}
*mean_j = sum / n_samples as f32;
}
let mut std = vec![0.0; n_features];
for (j, std_j) in std.iter_mut().enumerate() {
let mut sum_sq = 0.0;
for i in 0..n_samples {
let diff = x.get(i, j) - mean[j];
sum_sq += diff * diff;
}
*std_j = (sum_sq / n_samples as f32).sqrt();
}
self.mean = Some(mean);
self.std = Some(std);
Ok(())
}
fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
let mean = self
.mean
.as_ref()
.ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
let std = self
.std
.as_ref()
.ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
let (n_samples, n_features) = x.shape();
if n_features != mean.len() {
return Err("Feature dimension mismatch".into());
}
let mut result = vec![0.0; n_samples * n_features];
for i in 0..n_samples {
for j in 0..n_features {
let mut val = x.get(i, j);
if self.with_mean {
val -= mean[j];
}
if self.with_std && std[j] > 1e-10 {
val /= std[j];
}
result[i * n_features + j] = val;
}
}
Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
}
}
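/// Scales each feature to the `[feature_min, feature_max]` output range using
/// the per-feature minimum and maximum observed during fitting.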
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MinMaxScaler {
data_min: Option<Vec<f32>>,
data_max: Option<Vec<f32>>,
feature_min: f32,
feature_max: f32,
}
impl Default for MinMaxScaler {
fn default() -> Self {
Self::new()
}
}
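/// Scales features using statistics that are robust to outliers: the median
/// for centering and the interquartile range (IQR) for scaling, each of which
/// can be toggled independently.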
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustScaler {
median: Option<Vec<f32>>,
iqr: Option<Vec<f32>>,
with_centering: bool,
with_scaling: bool,
}
impl Default for RobustScaler {
fn default() -> Self {
Self::new()
}
}
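// The remaining scaler implementations (e.g. the MinMaxScaler and RobustScaler
// methods) are included textually from a sibling file.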
include!("mod_include_01.rs");
#[cfg(test)]
#[path = "tests_normalization_contract.rs"]
mod tests_normalization_contract;