use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_symbolic::regression::{discover, DiscoveredFormula, SrConfig};
use std::fmt;
/// Tuning knobs for [`extract_formula_from_callable`].
///
/// Every field is forwarded, unchanged, to a builder method of the
/// symbolic-regression backend's `SrConfig` when the search is set up.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormulaExtractionConfig {
    /// Upper bound on formula size; forwarded to `SrConfig::with_max_nodes`.
    pub max_complexity: usize,
    /// Number of candidates kept per step; forwarded to `SrConfig::with_beam_width`.
    pub population_size: usize,
    /// Number of search iterations; forwarded to `SrConfig::with_max_iter`.
    pub n_generations: usize,
    /// How many formulas to return; forwarded to `SrConfig::with_top_n`.
    pub n_results: usize,
    /// Seed for the search; forwarded to `SrConfig::with_seed`.
    pub seed: u64,
}
impl Default for FormulaExtractionConfig {
    /// Builds the baseline search budget used when the caller has no
    /// particular preferences: complexity 10, 150 candidates, 30 iterations,
    /// the top 3 formulas, and a fixed seed of 42.
    fn default() -> Self {
        let max_complexity = 10;
        let population_size = 150;
        let n_generations = 30;
        let n_results = 3;
        let seed = 42;

        Self {
            max_complexity,
            population_size,
            n_generations,
            n_results,
            seed,
        }
    }
}
impl FormulaExtractionConfig {
    /// Translates this configuration into the backend's `SrConfig`.
    ///
    /// Starts from `SrConfig::default()` and overrides, in order:
    /// `max_complexity` via `with_max_nodes`, `population_size` via
    /// `with_beam_width`, `n_generations` via `with_max_iter`,
    /// `n_results` via `with_top_n`, and `seed` via `with_seed`.
    fn to_sr_config(&self) -> SrConfig {
        // Exhaustive destructure: adding a field to the struct forces this
        // mapping to be revisited at compile time.
        let Self {
            max_complexity,
            population_size,
            n_generations,
            n_results,
            seed,
        } = *self;

        let base = SrConfig::default();
        base.with_max_nodes(max_complexity)
            .with_beam_width(population_size)
            .with_max_iter(n_generations)
            .with_top_n(n_results)
            .with_seed(seed)
    }
}
/// Errors returned by [`extract_formula_from_callable`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. with `assert_eq!`) instead of only pattern-matching them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormulaExtractionError {
    /// `input_grid` had zero rows, so there were no samples to fit.
    EmptyGrid,
    /// The callable's output was not shaped `(n_samples, 1)`.
    DimensionMismatch,
    /// Free-form message describing a symbolic-regression failure; rendered as
    /// `symbolic regression error: {msg}`.
    ///
    /// NOTE(review): `extract_formula_from_callable` never constructs this
    /// variant — confirm whether any other code path does.
    SymbolicError(String),
}
impl fmt::Display for FormulaExtractionError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::SymbolicError(msg) => {
                // Only this variant carries a payload that must be interpolated.
                return write!(f, "symbolic regression error: {msg}");
            }
            Self::EmptyGrid => "input_grid has zero rows — nothing to extract from",
            Self::DimensionMismatch => {
                "callable returned an array with wrong shape: expected (n_samples, 1), check callable output dimensions"
            }
        };
        f.write_str(message)
    }
}
/// Makes `FormulaExtractionError` usable as a standard error (e.g. boxed as
/// `Box<dyn std::error::Error>`). No variant wraps another error value, so the
/// default `source()` implementation is kept.
impl std::error::Error for FormulaExtractionError {}
/// Fits symbolic formulas to a black-box function by sampling it on a grid.
///
/// `f` is evaluated exactly once on `input_grid`, and each row of `input_grid`
/// is treated as one sample. NOTE(review): presumably the columns are the
/// input features — confirm against `discover`. The single output column
/// returned by `f` becomes the regression target. The search is configured
/// from `config` and returns whatever formulas `discover` reports.
///
/// # Errors
///
/// - [`FormulaExtractionError::EmptyGrid`] if `input_grid` has zero rows
///   (checked before `f` is called).
/// - [`FormulaExtractionError::DimensionMismatch`] if `f` returns an array
///   whose shape is not `(input_grid.nrows(), 1)`.
pub fn extract_formula_from_callable<F>(
    f: F,
    input_grid: ArrayView2<'_, f64>,
    config: &FormulaExtractionConfig,
) -> Result<Vec<DiscoveredFormula>, FormulaExtractionError>
where
    // `FnOnce` rather than `Fn`: the callable is invoked exactly once, so this
    // accepts strictly more closures (every `Fn` is also `FnOnce`).
    F: FnOnce(ArrayView2<'_, f64>) -> Array2<f64>,
{
    let n_samples = input_grid.nrows();
    if n_samples == 0 {
        return Err(FormulaExtractionError::EmptyGrid);
    }

    let output = f(input_grid);
    // `Array2` is always rank 2, so comparing the `(rows, cols)` pair is a
    // complete shape check; no separate rank test is needed.
    if output.dim() != (n_samples, 1) {
        return Err(FormulaExtractionError::DimensionMismatch);
    }

    let sr_cfg = config.to_sr_config();
    // Borrow the target column directly instead of copying it into a new array.
    Ok(discover(input_grid, output.column(0), &sr_cfg))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an `(n, 1)` grid whose single column is `i * scale / n` for row `i`.
    fn linear_grid(n: usize, scale: f64) -> Array2<f64> {
        Array2::from_shape_fn((n, 1), |(i, _)| i as f64 * scale / n as f64)
    }

    #[test]
    fn test_empty_grid_returns_err() {
        let grid = Array2::<f64>::zeros((0, 1));
        let cfg = FormulaExtractionConfig::default();
        let result = extract_formula_from_callable(
            |x: ArrayView2<'_, f64>| Array2::zeros((x.shape()[0], 1)),
            grid.view(),
            &cfg,
        );
        assert!(
            matches!(result, Err(FormulaExtractionError::EmptyGrid)),
            "expected EmptyGrid, got {:?}",
            result.err().map(|e| e.to_string())
        );
    }

    #[test]
    fn test_dim_mismatch_returns_err() {
        let grid = linear_grid(4, 1.0);
        let cfg = FormulaExtractionConfig::default();
        // Wrong number of rows: 3 instead of 4.
        let result = extract_formula_from_callable(
            |_x: ArrayView2<'_, f64>| Array2::zeros((3, 1)),
            grid.view(),
            &cfg,
        );
        assert!(
            matches!(result, Err(FormulaExtractionError::DimensionMismatch)),
            "expected DimensionMismatch, got {:?}",
            result.err().map(|e| e.to_string())
        );
    }

    #[test]
    fn test_wrong_column_count_returns_err() {
        let grid = linear_grid(4, 1.0);
        let cfg = FormulaExtractionConfig::default();
        // Right number of rows, but two output columns instead of one.
        let result = extract_formula_from_callable(
            |x: ArrayView2<'_, f64>| Array2::zeros((x.shape()[0], 2)),
            grid.view(),
            &cfg,
        );
        assert!(
            matches!(result, Err(FormulaExtractionError::DimensionMismatch)),
            "expected DimensionMismatch, got {:?}",
            result.err().map(|e| e.to_string())
        );
    }

    #[test]
    fn test_error_messages() {
        assert_eq!(
            FormulaExtractionError::SymbolicError("boom".to_string()).to_string(),
            "symbolic regression error: boom"
        );
        assert!(FormulaExtractionError::EmptyGrid
            .to_string()
            .contains("zero rows"));
        assert!(FormulaExtractionError::DimensionMismatch
            .to_string()
            .contains("expected (n_samples, 1)"));
    }
}