use numpy::{PyReadonlyArray1, PyReadonlyArray2};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
/// Python-facing `SymRegConfig`: wraps the core search configuration and adds a cap on
/// how many formulas a search returns.
#[pyclass(name = "SymRegConfig", from_py_object)]
#[derive(Clone)]
pub struct PySymRegConfig {
/// Core configuration; a clone of it is passed to `crate::symreg::SymRegEngine::new`
/// when `SymRegEngine.discover` runs.
inner: crate::symreg::SymRegConfig,
/// Maximum number of formulas `SymRegEngine.discover` returns; 0 means no cap is applied.
pub max_formulas: usize,
}
#[pymethods]
impl PySymRegConfig {
    /// Configuration built from the core `quick` preset, with no cap on returned formulas.
    #[staticmethod]
    pub fn quick() -> Self {
        Self::uncapped(crate::symreg::SymRegConfig::quick())
    }

    /// Configuration built from the core `balanced` preset, with no cap on returned formulas.
    #[staticmethod]
    pub fn balanced() -> Self {
        Self::uncapped(crate::symreg::SymRegConfig::balanced())
    }

    /// Configuration built from the core `exhaustive` preset, with no cap on returned formulas.
    #[staticmethod]
    pub fn exhaustive() -> Self {
        Self::uncapped(crate::symreg::SymRegConfig::exhaustive())
    }

    /// Current value of the core configuration's `max_depth` field.
    #[getter]
    pub fn depth_limit(&self) -> usize {
        self.inner.max_depth
    }

    /// Overwrite the core configuration's `max_depth` field.
    #[setter]
    pub fn set_depth_limit(&mut self, v: usize) {
        self.inner.max_depth = v;
    }

    /// Maximum number of formulas a search returns; 0 means no cap.
    #[getter]
    pub fn get_max_formulas(&self) -> usize {
        self.max_formulas
    }

    /// Set the maximum number of formulas a search returns; 0 removes the cap.
    #[setter]
    pub fn set_max_formulas(&mut self, v: usize) {
        self.max_formulas = v;
    }

    /// Current value of the core configuration's `max_iter` field.
    #[getter]
    pub fn adam_steps(&self) -> usize {
        self.inner.max_iter
    }

    /// Overwrite the core configuration's `max_iter` field.
    #[setter]
    pub fn set_adam_steps(&mut self, v: usize) {
        self.inner.max_iter = v;
    }

    /// Optional seed stored in the core configuration (`None` when unset).
    #[getter]
    pub fn seed(&self) -> Option<u64> {
        self.inner.seed
    }

    /// Store or clear (`None`) the seed in the core configuration.
    #[setter]
    pub fn set_seed(&mut self, v: Option<u64>) {
        self.inner.seed = v;
    }

    /// Debug representation listing the depth limit, step count and formula cap.
    pub fn __repr__(&self) -> String {
        let core = &self.inner;
        format!(
            "SymRegConfig(depth_limit={}, adam_steps={}, max_formulas={})",
            core.max_depth, core.max_iter, self.max_formulas
        )
    }
}

impl PySymRegConfig {
    /// Wrap a core configuration with `max_formulas` set to 0 (no cap).
    ///
    /// Kept outside the `#[pymethods]` block because it takes a core type and is not
    /// meant to be callable from Python.
    fn uncapped(inner: crate::symreg::SymRegConfig) -> Self {
        Self {
            inner,
            max_formulas: 0,
        }
    }
}
/// Python-facing `DiscoveredFormula`: a read-only view over one formula returned by
/// `SymRegEngine.discover`.
#[pyclass(name = "DiscoveredFormula", from_py_object)]
#[derive(Clone)]
pub struct PyDiscoveredFormula {
/// The core formula this object wraps; all getters read from it.
inner: crate::symreg::DiscoveredFormula,
}
#[pymethods]
impl PyDiscoveredFormula {
    /// The formula's `pretty` string as stored on the core result.
    #[getter]
    pub fn pretty(&self) -> &str {
        self.inner.pretty.as_str()
    }

    /// The `mse` value stored on the core result.
    #[getter]
    pub fn mse(&self) -> f64 {
        self.inner.mse
    }

    /// The `complexity` value stored on the core result.
    #[getter]
    pub fn complexity(&self) -> usize {
        self.inner.complexity
    }

    /// The `score` value stored on the core result.
    #[getter]
    pub fn score(&self) -> f64 {
        self.inner.score
    }

    /// The `cv_mse` value stored on the core result, or `None` when it is absent.
    #[getter]
    pub fn cv_mse(&self) -> Option<f64> {
        self.inner.cv_mse
    }

    /// LaTeX rendering, delegated to the core formula's `to_latex`.
    pub fn to_latex(&self) -> String {
        self.inner.to_latex()
    }

    /// Evaluate the formula on the values in `xs` and return the result.
    ///
    /// The stored tree is lowered and simplified on every call before evaluation.
    /// Raises `ValueError` when `xs` has fewer elements than the variable count
    /// reported by the simplified expression; extra elements are accepted.
    pub fn eval(&self, xs: Vec<f64>) -> PyResult<f64> {
        let expr = self.inner.eml_tree.lower().simplify();
        let required = expr.count_vars();
        let supplied = xs.len();
        if supplied >= required {
            return Ok(expr.eval(&xs));
        }
        Err(PyValueError::new_err(format!(
            "formula references {} variable(s) but xs has only {} element(s)",
            required, supplied
        )))
    }

    /// Debug representation with the quoted pretty string, `mse` to six decimals and
    /// the complexity.
    pub fn __repr__(&self) -> String {
        let formula = &self.inner;
        format!(
            "DiscoveredFormula(pretty={:?}, mse={:.6}, complexity={})",
            formula.pretty, formula.mse, formula.complexity
        )
    }
}
/// Python-facing `SymRegEngine`: holds a configuration and runs formula discovery with it.
#[pyclass(name = "SymRegEngine")]
pub struct PySymRegEngine {
/// The engine's own copy of the configuration, cloned from the object passed to the constructor.
config: PySymRegConfig,
}
#[pymethods]
impl PySymRegEngine {
    /// Create an engine holding its own clone of `config`.
    #[new]
    pub fn new(config: &PySymRegConfig) -> Self {
        Self {
            config: config.clone(),
        }
    }

    /// Search for formulas relating the rows of `x` to the targets in `y`.
    ///
    /// `x` is a 2-D float64 array with one row per sample and `y` is a 1-D float64
    /// array with one target per row of `x`. Returns the formulas produced by the core
    /// engine, in the order it returned them, cut down to the first `max_formulas`
    /// entries when that setting is non-zero.
    ///
    /// Raises `ValueError` when the row count of `x` differs from the length of `y`,
    /// when there are no samples, or when the core engine reports an error.
    pub fn discover<'py>(
        &self,
        py: Python<'py>,
        x: PyReadonlyArray2<'py, f64>,
        y: PyReadonlyArray1<'py, f64>,
    ) -> PyResult<Vec<PyDiscoveredFormula>> {
        let x_arr = x.as_array();
        let y_arr = y.as_array();
        let n_samples = y_arr.len();
        let n_features = x_arr.ncols();
        if x_arr.nrows() != n_samples {
            return Err(PyValueError::new_err(format!(
                "X has {} rows but y has {} elements",
                x_arr.nrows(),
                n_samples
            )));
        }
        if n_samples == 0 {
            return Err(PyValueError::new_err("input arrays must not be empty"));
        }

        // Copy the data into owned vectors (one Vec per sample row) so the search below
        // works on plain Rust data. Iterating the outer axis cannot go out of bounds,
        // so no per-element error path is needed.
        let inputs: Vec<Vec<f64>> = x_arr.outer_iter().map(|row| row.to_vec()).collect();
        let targets: Vec<f64> = y_arr.to_vec();

        let engine = crate::symreg::SymRegEngine::new(self.config.inner.clone());
        let max_formulas = self.config.max_formulas;

        // Run the search inside `py.detach`, then surface any core error as ValueError.
        let mut formulas = py
            .detach(|| engine.discover(&inputs, &targets, n_features))
            .map_err(|e| PyValueError::new_err(e.to_string()))?;

        // 0 means "no cap"; `truncate` is already a no-op when the list is short enough.
        if max_formulas > 0 {
            formulas.truncate(max_formulas);
        }

        Ok(formulas
            .into_iter()
            .map(|inner| PyDiscoveredFormula { inner })
            .collect())
    }

    /// Debug representation listing the configured depth limit and step count.
    pub fn __repr__(&self) -> String {
        format!(
            "SymRegEngine(depth_limit={}, adam_steps={})",
            self.config.inner.max_depth, self.config.inner.max_iter
        )
    }
}
/// Initializer for the `_core` extension module: registers the configuration, formula
/// and engine classes on the module.
#[pymodule]
pub fn _core(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PySymRegConfig>()?;
    m.add_class::<PyDiscoveredFormula>()?;
    // The last registration's result doubles as the function's result.
    m.add_class::<PySymRegEngine>()
}