use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::simd_ops::SimdUnifiedOps;
use scirs2_core::validation::{check_not_empty, check_positive};
use crate::error::{Result, TransformError};
/// SIMD-accelerated generator of polynomial features (bias, linear terms,
/// powers and cross-products) for 2-D input matrices.
///
/// Configured once via [`SimdPolynomialFeatures::new`] and applied with
/// `transform`; the generic element type `F` must support the crate's
/// unified SIMD operations.
pub struct SimdPolynomialFeatures<F: Float + NumCast + SimdUnifiedOps> {
    // Maximum total degree of the generated monomials (>= 1, enforced in `new`).
    degree: usize,
    // When true, a leading constant column of ones is emitted.
    include_bias: bool,
    // When true, only products of *distinct* features are generated
    // (no squares or higher powers of a single feature).
    interaction_only: bool,
    // Marker binding the element type `F` to the struct without storing one.
    _phantom: std::marker::PhantomData<F>,
}
impl<F: Float + NumCast + SimdUnifiedOps> SimdPolynomialFeatures<F> {
    /// Creates a polynomial feature generator for monomials up to `degree`.
    ///
    /// # Errors
    /// Returns `TransformError::InvalidInput` if `degree` is 0.
    pub fn new(degree: usize, include_bias: bool, interactiononly: bool) -> Result<Self> {
        if degree == 0 {
            return Err(TransformError::InvalidInput(
                "Degree must be at least 1".to_string(),
            ));
        }
        Ok(SimdPolynomialFeatures {
            degree,
            include_bias,
            interaction_only: interactiononly,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Expands each row of `x` into its polynomial feature vector.
    ///
    /// The output has shape `(n_samples, n_output_features)` where the width
    /// is computed by `calculate_n_outputfeatures`.
    ///
    /// # Errors
    /// - `DataValidationError` if `x` contains non-finite values.
    /// - `InvalidInput` if `x` is empty or has more than 1000 columns.
    /// - `ComputationError` if the expanded matrix would be excessively large.
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<F>>
    where
        S: Data<Elem = F>,
    {
        check_not_empty(x, "x")?;
        if x.iter().any(|v| !v.is_finite()) {
            return Err(crate::error::TransformError::DataValidationError(
                "Data contains non-finite values".to_string(),
            ));
        }
        let n_samples = x.shape()[0];
        let nfeatures = x.shape()[1];
        if n_samples == 0 || nfeatures == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }
        // Combinatorial blow-up guard: the expansion is super-linear in width.
        if nfeatures > 1000 {
            return Err(TransformError::InvalidInput(
                "Too many features for polynomial expansion (>1000)".to_string(),
            ));
        }
        let n_outputfeatures = self.calculate_n_outputfeatures(nfeatures)?;
        if n_samples > 100_000 && n_outputfeatures > 10_000 {
            return Err(TransformError::ComputationError(
                "Output matrix would be too large (>1B elements)".to_string(),
            ));
        }
        let mut output = Array2::zeros((n_samples, n_outputfeatures));
        // Process rows in cache-friendly batches sized for L1.
        let batch_size = self.calculate_optimal_batch_size(n_samples, n_outputfeatures);
        for batch_start in (0..n_samples).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(n_samples);
            for i in batch_start..batch_end {
                let poly_features = self.transform_sample_simd(&x.row(i))?;
                if poly_features.len() != n_outputfeatures {
                    return Err(TransformError::ComputationError(
                        "Feature count mismatch in polynomial expansion".to_string(),
                    ));
                }
                let mut output_row = output.row_mut(i);
                for (j, &val) in poly_features.iter().enumerate() {
                    output_row[j] = val;
                }
            }
        }
        Ok(output)
    }

    /// Expands a single sample into its full polynomial feature vector.
    fn transform_sample_simd(&self, sample: &ArrayView1<F>) -> Result<Array1<F>> {
        let nfeatures = sample.len();
        let n_outputfeatures = self.calculate_n_outputfeatures(nfeatures)?;
        let mut output = Array1::zeros(n_outputfeatures);
        let mut output_idx = 0;
        if self.include_bias {
            output[output_idx] = F::one();
            output_idx += 1;
        }
        // Degree-1 terms: the input features themselves.
        for j in 0..nfeatures {
            output[output_idx] = sample[j];
            output_idx += 1;
        }
        if self.degree > 1 {
            if self.interaction_only {
                // BUGFIX: previously only degree-2 interactions were generated,
                // leaving the slots counted for degrees 3..=self.degree filled
                // with zeros. Emit every requested interaction degree.
                for d in 2..=self.degree {
                    output_idx =
                        self.add_interaction_terms(sample, &mut output, output_idx, d)?;
                }
            } else {
                output_idx = self.add_polynomial_terms(sample, &mut output, output_idx)?;
            }
        }
        // Every slot must have been written exactly once.
        debug_assert_eq!(output_idx, n_outputfeatures);
        Ok(output)
    }

    /// Writes all monomials of degree 2..=self.degree (powers included) into
    /// `output` starting at `output_idx`; returns the next free index.
    fn add_polynomial_terms(
        &self,
        sample: &ArrayView1<F>,
        output: &mut Array1<F>,
        mut output_idx: usize,
    ) -> Result<usize> {
        let nfeatures = sample.len();
        if self.degree == 2 {
            // SIMD fast path for the common quadratic case:
            // squares first, then all pairwise cross-products.
            let squared = F::simd_mul(&sample.view(), &sample.view());
            for j in 0..nfeatures {
                output[output_idx] = squared[j];
                output_idx += 1;
            }
            for j in 0..nfeatures {
                let remaining_features = nfeatures - j - 1;
                if remaining_features > 0 {
                    let sample_j = sample[j];
                    let remaining_slice = sample.slice(scirs2_core::ndarray::s![j + 1..]);
                    let sample_j_vec = Array1::from_elem(remaining_features, sample_j);
                    let cross_products = F::simd_mul(&sample_j_vec.view(), &remaining_slice);
                    for &val in cross_products.iter() {
                        output[output_idx] = val;
                        output_idx += 1;
                    }
                }
            }
        } else {
            // General case: enumerate monomials degree by degree.
            for current_degree in 2..=self.degree {
                output_idx = self.add_degree_terms(sample, output, output_idx, current_degree)?;
            }
        }
        Ok(output_idx)
    }

    /// Writes products of `degree` *distinct* features into `output` starting
    /// at `output_idx`; returns the next free index.
    fn add_interaction_terms(
        &self,
        sample: &ArrayView1<F>,
        output: &mut Array1<F>,
        mut output_idx: usize,
        degree: usize,
    ) -> Result<usize> {
        let nfeatures = sample.len();
        if degree == 2 {
            // SIMD fast path: for each pivot j, multiply by all later features.
            // (Removed a dead scalar else-branch: when remaining == 0 the inner
            // range j+1..nfeatures was empty anyway.)
            for j in 0..nfeatures {
                let remaining_features = nfeatures - j - 1;
                if remaining_features == 0 {
                    continue;
                }
                let remaining_slice = sample.slice(scirs2_core::ndarray::s![j + 1..]);
                let sample_j_vec = Array1::from_elem(remaining_features, sample[j]);
                let interactions = F::simd_mul(&sample_j_vec.view(), &remaining_slice);
                for &val in interactions.iter() {
                    output[output_idx] = val;
                    output_idx += 1;
                }
            }
        } else {
            // General case: products over strictly increasing index tuples.
            for idx_set in self.generate_interaction_indices(nfeatures, degree) {
                let mut prod = F::one();
                for &idx in &idx_set {
                    prod = prod * sample[idx];
                }
                output[output_idx] = prod;
                output_idx += 1;
            }
        }
        Ok(output_idx)
    }

    /// Writes all monomials of exactly `degree` (repetition allowed) into
    /// `output` starting at `output_idx`; returns the next free index.
    fn add_degree_terms(
        &self,
        sample: &ArrayView1<F>,
        output: &mut Array1<F>,
        mut output_idx: usize,
        degree: usize,
    ) -> Result<usize> {
        let nfeatures = sample.len();
        for idx_vec in self.generate_degree_indices(nfeatures, degree) {
            let mut prod = F::one();
            for &idx in &idx_vec {
                prod = prod * sample[idx];
            }
            output[output_idx] = prod;
            output_idx += 1;
        }
        Ok(output_idx)
    }

    /// Heuristic batch size keeping a batch of output rows roughly within half
    /// of a 32 KiB L1 data cache.
    fn calculate_optimal_batch_size(&self, n_samples: usize, n_outputfeatures: usize) -> usize {
        const L1_CACHE_SIZE: usize = 32_768;
        let element_size = std::mem::size_of::<F>();
        let elements_per_batch = L1_CACHE_SIZE / element_size / 2;
        let max_batch_size = elements_per_batch / n_outputfeatures.max(1);
        let optimal_batch_size = if n_outputfeatures > 1000 {
            16.max(max_batch_size).min(64)
        } else if n_samples > 50_000 {
            64.max(max_batch_size).min(256)
        } else {
            128.max(max_batch_size).min(512)
        };
        optimal_batch_size.min(n_samples)
    }

    /// Number of output columns produced for `nfeatures` input columns.
    fn calculate_n_outputfeatures(&self, nfeatures: usize) -> Result<usize> {
        let mut count = if self.include_bias { 1 } else { 0 };
        // Degree-1 terms.
        count += nfeatures;
        if self.degree > 1 {
            if self.interaction_only {
                // Sum of C(n, d) distinct-feature products for each degree.
                for d in 2..=self.degree {
                    count += self.n_choose_k(nfeatures, d);
                }
            } else {
                // C(n + d, d) counts ALL monomials of total degree <= d,
                // including the constant term and the n linear terms; both are
                // accounted for above.
                // BUGFIX: the constant must be subtracted unconditionally — it
                // was previously only removed when include_bias was set, which
                // overcounted the width by one (a trailing zero column) for
                // include_bias == false.
                count += self.n_polynomial_features(nfeatures, self.degree) - nfeatures - 1;
            }
        }
        Ok(count)
    }

    /// Binomial coefficient C(n, k). The multiplicative form divides at each
    /// step, so intermediate values stay exact (each prefix product of i+1
    /// consecutive integers is divisible by (i+1)!).
    fn n_choose_k(&self, n: usize, k: usize) -> usize {
        if k > n {
            return 0;
        }
        if k == 0 || k == n {
            return 1;
        }
        let mut result = 1;
        for i in 0..k {
            result = result * (n - i) / (i + 1);
        }
        result
    }

    /// Total number of monomials of degree <= `degree` over `nfeatures`
    /// variables (constant included): C(n + d, d).
    fn n_polynomial_features(&self, nfeatures: usize, degree: usize) -> usize {
        self.n_choose_k(nfeatures + degree, degree)
    }

    /// Enumerates all strictly increasing index tuples of length `degree`
    /// drawn from 0..nfeatures (i.e. the C(n, degree) combinations).
    fn generate_interaction_indices(&self, nfeatures: usize, degree: usize) -> Vec<Vec<usize>> {
        let mut indices = Vec::new();
        // BUGFIX: no strictly increasing tuple exists when degree exceeds the
        // feature count; the old code also underflowed `nfeatures - (degree-1-i)`
        // in that case.
        if degree == 0 || degree > nfeatures {
            return indices;
        }
        // BUGFIX: start from [0, 1, .., degree-1]; the old seed [0, 0, .., 0]
        // emitted tuples with repeated indices (e.g. x0*x0 as an "interaction").
        let mut current: Vec<usize> = (0..degree).collect();
        loop {
            indices.push(current.clone());
            let mut i = degree - 1;
            loop {
                current[i] += 1;
                // Position i may go up to nfeatures - (degree - i) so that the
                // remaining positions can still hold larger indices.
                if current[i] < nfeatures - (degree - 1 - i) {
                    for j in i + 1..degree {
                        current[j] = current[j - 1] + 1;
                    }
                    break;
                }
                if i == 0 {
                    return indices;
                }
                i -= 1;
            }
        }
    }

    /// Enumerates all non-decreasing index tuples of length `degree` drawn
    /// from 0..nfeatures — the C(n + d - 1, d) multisets, i.e. monomials of
    /// exactly degree `degree` with repetition allowed.
    fn generate_degree_indices(&self, nfeatures: usize, degree: usize) -> Vec<Vec<usize>> {
        let mut indices = Vec::new();
        let mut current = vec![0; degree];
        loop {
            indices.push(current.clone());
            let mut i = degree - 1;
            loop {
                current[i] += 1;
                if current[i] < nfeatures {
                    // Keep the tuple non-decreasing.
                    for j in i + 1..degree {
                        current[j] = current[i];
                    }
                    break;
                }
                if i == 0 {
                    return indices;
                }
                current[i] = 0;
                i -= 1;
            }
        }
    }
}
#[allow(dead_code)]
/// Applies a power transform (`"box-cox"` or `"yeo-johnson"`) with parameter
/// `lambda` to `data`, element-wise.
///
/// # Errors
/// - `InvalidInput` for an unknown `method`, or non-positive data with Box-Cox.
/// - Propagates errors from the vectorized power helper.
pub fn simd_power_transform<F>(data: &Array1<F>, lambda: F, method: &str) -> Result<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    let n = data.len();
    let mut result = Array1::zeros(n);
    // Tolerance for treating lambda (or 2 - lambda) as exactly zero.
    let eps = F::from(1e-6).expect("Failed to convert constant to float");
    let two = F::from(2.0).expect("Failed to convert constant to float");
    match method {
        "box-cox" => {
            let min_val = F::simd_min_element(&data.view());
            if min_val <= F::zero() {
                return Err(TransformError::InvalidInput(
                    "Box-Cox requires strictly positive values".to_string(),
                ));
            }
            if lambda.abs() < eps {
                // lambda -> 0 limit of (x^lambda - 1)/lambda is ln(x).
                for i in 0..n {
                    result[i] = data[i].ln();
                }
            } else {
                // (x^lambda - 1) / lambda, vectorized.
                let ones = Array1::from_elem(n, F::one());
                let powered = simd_array_pow(data, lambda)?;
                let numerator = F::simd_sub(&powered.view(), &ones.view());
                let lambda_array = Array1::from_elem(n, lambda);
                result = F::simd_div(&numerator.view(), &lambda_array.view());
            }
        }
        "yeo-johnson" => {
            for i in 0..n {
                let x = data[i];
                if x >= F::zero() {
                    if lambda.abs() < eps {
                        // BUGFIX: the lambda -> 0, x >= 0 limit is ln(x + 1);
                        // the old `x.ln() + 1` form was wrong and diverged at
                        // x = 0 (a value Yeo-Johnson explicitly supports).
                        result[i] = (x + F::one()).ln();
                    } else {
                        result[i] = ((x + F::one()).powf(lambda) - F::one()) / lambda;
                    }
                } else if (two - lambda).abs() < eps {
                    // lambda -> 2, x < 0 limit: -ln(1 - x).
                    result[i] = -((-x + F::one()).ln());
                } else {
                    result[i] =
                        -((-x + F::one()).powf(two - lambda) - F::one()) / (two - lambda);
                }
            }
        }
        _ => {
            return Err(TransformError::InvalidInput(
                "Method must be 'box-cox' or 'yeo-johnson'".to_string(),
            ));
        }
    }
    Ok(result)
}
#[allow(dead_code)]
/// Raises every element of `data` to `exponent`, using SIMD fast paths for
/// the common exponents 0, 0.5, 1, 2 and 3 and scalar `powf` otherwise.
///
/// # Errors
/// - `InvalidInput` if `exponent` is not finite.
/// - `ComputationError` for sqrt of a negative value, or when the fallback
///   path produces non-finite results.
fn simd_array_pow<F>(data: &Array1<F>, exponent: F) -> Result<Array1<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    let n = data.len();
    if n == 0 {
        return Ok(Array1::zeros(0));
    }
    if !exponent.is_finite() {
        return Err(TransformError::InvalidInput(
            "Exponent must be finite".to_string(),
        ));
    }
    let tol = F::from(1e-10).expect("Failed to convert constant to float");
    // True when `exponent` is within tolerance of `target`.
    let matches = |target: f64| -> bool {
        (exponent - F::from(target).expect("Failed to convert constant to float")).abs() < tol
    };
    let mut result = Array1::zeros(n);
    if matches(2.0) {
        // x^2 as a single SIMD multiply.
        result = F::simd_mul(&data.view(), &data.view());
    } else if matches(0.5) {
        // sqrt requires non-negative input.
        for &val in data.iter() {
            if val < F::zero() {
                return Err(TransformError::ComputationError(
                    "Cannot compute square root of negative values".to_string(),
                ));
            }
        }
        result = F::simd_sqrt(&data.view());
    } else if matches(3.0) {
        // x^3 as two SIMD multiplies.
        let squared = F::simd_mul(&data.view(), &data.view());
        result = F::simd_mul(&squared.view(), &data.view());
    } else if matches(1.0) {
        result = data.clone();
    } else if matches(0.0) {
        // x^0 == 1 by convention (including powf(0.0) at x == 0).
        result.fill(F::one());
    } else {
        // General fallback: scalar powf per element.
        // (Cleanup: removed an `exponent_array` allocation that was never used.)
        result = data.mapv(|x| x.powf(exponent));
        for &val in result.iter() {
            if !val.is_finite() {
                return Err(TransformError::ComputationError(
                    "Power operation produced non-finite values".to_string(),
                ));
            }
        }
    }
    Ok(result)
}
#[allow(dead_code)]
/// Binarizes `data`: elements strictly greater than `threshold` become 1,
/// all others become 0.
///
/// # Errors
/// - `DataValidationError` if `data` contains non-finite values.
/// - `InvalidInput` if `threshold` is not finite, or `data` is empty.
pub fn simd_binarize<F>(data: &Array2<F>, threshold: F) -> Result<Array2<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    check_not_empty(data, "data")?;
    if data.iter().any(|v| !v.is_finite()) {
        return Err(crate::error::TransformError::DataValidationError(
            "Data contains non-finite values".to_string(),
        ));
    }
    if !threshold.is_finite() {
        return Err(TransformError::InvalidInput(
            "Threshold must be finite".to_string(),
        ));
    }
    // Cleanup: the previous chunked implementation allocated an unused
    // threshold array per chunk and compared every element twice (once to
    // produce 0/1, then again against zero). A single pass is equivalent.
    Ok(data.mapv(|x| if x > threshold { F::one() } else { F::zero() }))
}
#[allow(dead_code)]
/// Picks a column-chunk size for row-wise processing: small chunks when a row
/// overflows the cache budget, larger ones otherwise, clamped to
/// `[16, ncols]`-ish (the lower bound wins when `ncols < 16`).
fn calculate_adaptive_chunk_size(n_rows: usize, ncols: usize) -> usize {
    // Budget roughly a quarter of a 32 KiB L1 cache, assuming 8-byte elements.
    const L1_CACHE_SIZE: usize = 32_768;
    const F64_SIZE: usize = 8;
    let cache_budget = L1_CACHE_SIZE / F64_SIZE / 4;

    let base = if ncols > cache_budget {
        // Rows are wider than the budget: keep chunks tiny.
        32
    } else if n_rows > 10_000 {
        // Many rows: moderate chunks amortize loop overhead.
        128
    } else {
        256
    };

    // Never exceed the row width, but keep at least 16 elements per chunk.
    base.min(ncols).max(16)
}
#[allow(dead_code)]
/// Validates input and dispatches polynomial feature expansion either in one
/// shot or through the chunked path when the input exceeds `memory_limit_mb`.
///
/// # Errors
/// - `DataValidationError` if `data` contains non-finite values.
/// - Propagates validation/construction errors from the transformer.
pub fn simd_polynomial_features_optimized<F>(
    data: &Array2<F>,
    degree: usize,
    include_bias: bool,
    interaction_only: bool,
    memory_limit_mb: usize,
) -> Result<Array2<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    check_not_empty(data, "data")?;
    if data.iter().any(|v| !v.is_finite()) {
        return Err(crate::error::TransformError::DataValidationError(
            "Data contains non-finite values".to_string(),
        ));
    }
    check_positive(degree, "degree")?;

    let transformer = SimdPolynomialFeatures::new(degree, include_bias, interaction_only)?;

    // Estimate the input footprint in whole megabytes.
    let (rows, cols) = (data.shape()[0], data.shape()[1]);
    let input_mb = rows * cols * std::mem::size_of::<F>() / (1024 * 1024);

    if input_mb > memory_limit_mb {
        simd_polynomial_features_chunked(data, &transformer, memory_limit_mb)
    } else {
        transformer.transform(data)
    }
}
#[allow(dead_code)]
/// Runs polynomial feature expansion over row chunks sized to stay under
/// `memory_limit_mb`, assembling the full output matrix.
///
/// # Errors
/// Returns `MemoryError` when the limit cannot accommodate even one row;
/// otherwise propagates errors from `transform`.
fn simd_polynomial_features_chunked<F>(
    data: &Array2<F>,
    poly_features: &SimdPolynomialFeatures<F>,
    memory_limit_mb: usize,
) -> Result<Array2<F>>
where
    F: Float + NumCast + SimdUnifiedOps,
{
    let n_rows = data.shape()[0];
    let n_cols = data.shape()[1];
    // Budget roughly half the limit for input rows (the factor of 2 leaves
    // headroom for the expanded output).
    let rows_per_chunk =
        (memory_limit_mb * 1024 * 1024) / (n_cols * std::mem::size_of::<F>() * 2);
    if rows_per_chunk == 0 {
        return Err(TransformError::MemoryError(
            "Memory limit too small for processing".to_string(),
        ));
    }

    // Transform the first chunk up front to learn the output width.
    let head = rows_per_chunk.min(n_rows);
    let head_result =
        poly_features.transform(&data.slice(scirs2_core::ndarray::s![0..head, ..]))?;
    let n_out = head_result.shape()[1];

    let mut output = Array2::zeros((n_rows, n_out));
    output
        .slice_mut(scirs2_core::ndarray::s![0..head, ..])
        .assign(&head_result);

    // Process the remaining rows chunk by chunk.
    for start in (head..n_rows).step_by(rows_per_chunk) {
        let end = (start + rows_per_chunk).min(n_rows);
        let chunk_result =
            poly_features.transform(&data.slice(scirs2_core::ndarray::s![start..end, ..]))?;
        output
            .slice_mut(scirs2_core::ndarray::s![start..end, ..])
            .assign(&chunk_result);
    }
    Ok(output)
}