use ndarray::{Array1, Array2};
use crate::error::EmlError;
use crate::symreg::{DiscoveredFormula, SymRegConfig, SymRegEngine};
/// Runs symbolic regression for a single target column.
///
/// `features` is read row by row (one row per sample, one column per feature) and
/// `targets` must contain one value per row of `features`. The arrays may be in any
/// memory layout (contiguous, strided, or reversed); values are copied in logical order.
///
/// # Errors
///
/// * [`EmlError::DimensionMismatch`]`(features.nrows(), targets.len())` when the number of
///   feature rows differs from the number of targets.
/// * [`EmlError::EmptyData`] when there are no samples.
/// * Any error returned by `SymRegEngine::discover`.
pub fn symbolic_regression(
    features: &Array2<f64>,
    targets: &Array1<f64>,
    config: &SymRegConfig,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    let n_samples = features.nrows();
    let n_features = features.ncols();
    if n_samples != targets.len() {
        return Err(EmlError::DimensionMismatch(n_samples, targets.len()));
    }
    if n_samples == 0 {
        return Err(EmlError::EmptyData);
    }
    let inputs: Vec<Vec<f64>> = (0..n_samples).map(|i| features.row(i).to_vec()).collect();
    // `as_slice()` yields `None` for arrays that are not in standard layout (e.g. after
    // `invert_axis` or stepped slicing), which previously surfaced as a bogus
    // `DimensionMismatch(n, 0)`. `to_vec()` copies in logical order for any layout.
    let target_values: Vec<f64> = targets.to_vec();
    let engine = SymRegEngine::new(config.clone());
    engine.discover(&inputs, target_values.as_slice(), n_features)
}
/// Runs symbolic regression for several target columns at once.
///
/// Each row of `features` and of `targets_matrix` is one sample; each column of
/// `targets_matrix` is one output. The columns are copied into separate vectors and
/// handed to `SymRegEngine::discover_multi` together with the per-sample feature rows.
/// The returned `Vec` of formula lists is whatever the engine produces (presumably one
/// list per output column — confirm against `discover_multi`).
///
/// # Errors
///
/// * [`EmlError::DimensionMismatch`]`(features.nrows(), targets_matrix.nrows())` when the
///   two matrices disagree on the number of samples.
/// * [`EmlError::EmptyData`] when there are no samples.
/// * Any error returned by `SymRegEngine::discover_multi`.
pub fn symbolic_regression_multi(
    features: &ndarray::Array2<f64>,
    targets_matrix: &ndarray::Array2<f64>,
    config: &SymRegConfig,
) -> Result<Vec<Vec<DiscoveredFormula>>, EmlError> {
    let (sample_count, feature_count) = features.dim();
    let target_rows = targets_matrix.nrows();
    if target_rows != sample_count {
        return Err(EmlError::DimensionMismatch(sample_count, target_rows));
    }
    if sample_count == 0 {
        return Err(EmlError::EmptyData);
    }

    // One Vec per sample (row-major copy of the feature matrix).
    let sample_rows: Vec<Vec<f64>> = features.outer_iter().map(|row| row.to_vec()).collect();
    // Iterating the rows of the transpose walks the original columns, one Vec per output.
    let output_columns: Vec<Vec<f64>> = targets_matrix
        .t()
        .outer_iter()
        .map(|column| column.to_vec())
        .collect();

    SymRegEngine::new(config.clone()).discover_multi(&sample_rows, &output_columns, feature_count)
}
/// Runs [`symbolic_regression`] and rewrites each formula's `pretty` string so that the
/// variable placeholders `x0`, `x1`, … read as the corresponding `feature_names` entry.
///
/// `feature_names[i]` replaces placeholder `x{i}`. Substitution is a single pass, so an
/// inserted name is never itself rewritten, even if it looks like a placeholder.
/// Placeholders whose index has no matching name are left as-is.
///
/// # Errors
///
/// * [`EmlError::DimensionMismatch`]`(features.ncols(), feature_names.len())` when the
///   number of names differs from the number of feature columns.
/// * Any error returned by [`symbolic_regression`].
pub fn symbolic_regression_with_names(
    features: &Array2<f64>,
    targets: &Array1<f64>,
    feature_names: &[&str],
    config: &SymRegConfig,
) -> Result<Vec<DiscoveredFormula>, EmlError> {
    if feature_names.len() != features.ncols() {
        return Err(EmlError::DimensionMismatch(
            features.ncols(),
            feature_names.len(),
        ));
    }
    let mut formulas = symbolic_regression(features, targets, config)?;
    for formula in &mut formulas {
        formula.pretty = substitute_feature_names(&formula.pretty, feature_names);
    }
    Ok(formulas)
}

/// Replaces every `x<digits>` placeholder in `pretty` with `names[<digits>]` in one
/// left-to-right pass.
///
/// A placeholder is an `x` not preceded by an ASCII letter or `_` (so it does not start
/// inside a longer identifier), followed by the longest run of ASCII digits. Indices
/// that do not parse or are out of range for `names` are copied through unchanged.
fn substitute_feature_names(pretty: &str, names: &[&str]) -> String {
    let bytes = pretty.as_bytes();
    let mut out = String::with_capacity(pretty.len());
    // Start of the part of `pretty` that has not yet been copied to `out`.
    let mut pending = 0;
    let mut i = 0;
    while i < bytes.len() {
        let at_boundary =
            i == 0 || !(bytes[i - 1].is_ascii_alphabetic() || bytes[i - 1] == b'_');
        if bytes[i] != b'x' || !at_boundary {
            i += 1;
            continue;
        }
        // Greedy digit run, so `x12` is index 12 rather than `x1` followed by `2`.
        let digits_end = bytes[i + 1..]
            .iter()
            .position(|b| !b.is_ascii_digit())
            .map_or(bytes.len(), |offset| i + 1 + offset);
        // All boundaries here sit on ASCII bytes, so these `str` slices are valid.
        let name = pretty[i + 1..digits_end]
            .parse::<usize>()
            .ok()
            .and_then(|index| names.get(index));
        if let Some(name) = name {
            out.push_str(&pretty[pending..i]);
            out.push_str(name);
            pending = digits_end;
        }
        i = digits_end.max(i + 1);
    }
    out.push_str(&pretty[pending..]);
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    /// A well-formed 3x2 feature matrix with 3 targets runs without error.
    #[test]
    fn test_ndarray_conversion() {
        let features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test shape");
        let targets = array![10.0, 20.0, 30.0];
        let config = SymRegConfig::quick();
        let result = symbolic_regression(&features, &targets, &config);
        assert!(result.is_ok());
    }

    /// A row/target count mismatch reports `(feature rows, target len)`.
    #[test]
    fn test_dimension_mismatch() {
        let features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test shape");
        let targets = array![10.0, 20.0];
        let config = SymRegConfig::quick();
        let result = symbolic_regression(&features, &targets, &config);
        match result {
            Err(EmlError::DimensionMismatch(3, 2)) => {}
            other => panic!("expected DimensionMismatch(3, 2), got {other:?}"),
        }
    }

    /// Too many feature names reports `(feature columns, name count)`.
    #[test]
    fn test_feature_names_mismatch() {
        let features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test shape");
        let targets = array![10.0, 20.0, 30.0];
        let config = SymRegConfig::quick();
        let result = symbolic_regression_with_names(&features, &targets, &["a", "b", "c"], &config);
        match result {
            Err(EmlError::DimensionMismatch(2, 3)) => {}
            other => panic!("expected DimensionMismatch(2, 3), got {other:?}"),
        }
    }

    /// A noiseless linear target `y = 2x` yields at least one formula with a small error.
    #[test]
    fn test_symbolic_regression_linear() {
        let n = 20;
        let mut feat_data = Vec::with_capacity(n);
        let mut tgt_data = Vec::with_capacity(n);
        for i in 0..n {
            let x = (i as f64) * 0.5 + 0.1;
            feat_data.push(x);
            tgt_data.push(2.0 * x);
        }
        let features = Array2::from_shape_vec((n, 1), feat_data).expect("test shape");
        let targets = Array1::from_vec(tgt_data);
        let config = SymRegConfig {
            max_depth: 2,
            max_iter: 500,
            num_restarts: 2,
            ..SymRegConfig::default()
        };
        let formulas =
            symbolic_regression(&features, &targets, &config).expect("regression should succeed");
        assert!(!formulas.is_empty(), "should discover at least one formula");
        let best = &formulas[0];
        assert!(
            best.mse < 50.0,
            "best MSE should be reasonable, got {}",
            best.mse
        );
    }

    /// Every returned `pretty` string uses the supplied names instead of `x0`/`x1`.
    #[test]
    fn test_with_names_replaces_vars() {
        let features = Array2::from_shape_vec(
            (5, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0],
        )
        .expect("test shape");
        let targets = array![2.0, 6.0, 12.0, 20.0, 30.0];
        let config = SymRegConfig::quick();
        let formulas =
            symbolic_regression_with_names(&features, &targets, &["mass", "vel"], &config)
                .expect("regression should succeed");
        for formula in &formulas {
            assert!(
                !formula.pretty.contains("x0") && !formula.pretty.contains("x1"),
                "pretty should use feature names, got: {}",
                formula.pretty
            );
        }
    }
}