use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use scirs2_numpy::{PyReadonlyArray1, PyReadonlyArray2};
use scirs2_symbolic::eml::eval::{eval_real as rust_eval_real, EvalCtx};
use scirs2_symbolic::eml::{
grad as rust_grad, lower as rust_lower, simplify_op as rust_simplify_op,
Canonical as RustCanonical, EmlTree as RustEmlTree, LoweredOp as RustLoweredOp,
};
use scirs2_symbolic::regression::{discover as rust_discover, SrConfig as RustSrConfig};
/// Python-facing `EmlTree` class, exported under `scirs2.symbolic`.
///
/// A thin wrapper that owns one Rust `EmlTree`; every method on the class
/// forwards to that wrapped value.
// NOTE(review): `skip_from_py_object` presumably opts this `Clone` class out of
// the automatic by-value extraction impl — confirm against the PyO3 version in use.
#[pyclass(name = "EmlTree", module = "scirs2.symbolic", skip_from_py_object)]
#[derive(Clone)]
pub struct PyEmlTree {
// Wrapped Rust tree. Private: Python code reaches it only through the methods.
inner: RustEmlTree,
}
#[pymethods]
impl PyEmlTree {
    /// Return the tree produced by the Rust `EmlTree::one()` constructor.
    #[staticmethod]
    fn one() -> Self {
        let inner = RustEmlTree::one();
        Self { inner }
    }

    /// Return the tree produced by the Rust `EmlTree::var(idx)` constructor.
    #[staticmethod]
    fn var(idx: usize) -> Self {
        let inner = RustEmlTree::var(idx);
        Self { inner }
    }

    /// Combine `left` and `right` through the Rust `EmlTree::eml` constructor.
    ///
    /// Both arguments are only borrowed; the result is a new tree.
    #[staticmethod]
    fn eml(left: &Self, right: &Self) -> Self {
        let inner = RustEmlTree::eml(&left.inner, &right.inner);
        Self { inner }
    }

    /// Depth of the wrapped tree, as reported by `EmlTree::depth`.
    fn depth(&self) -> usize {
        let tree = &self.inner;
        tree.depth()
    }

    /// Size of the wrapped tree, as reported by `EmlTree::size`.
    fn size(&self) -> usize {
        let tree = &self.inner;
        tree.size()
    }

    /// Variable count of the wrapped tree, as reported by `EmlTree::num_vars`.
    fn num_vars(&self) -> usize {
        let tree = &self.inner;
        tree.num_vars()
    }

    /// Structural hash of the tree, split into `(high 64 bits, low 64 bits)`.
    ///
    /// The Rust hash is wider than 64 bits, so it is handed to Python as a
    /// pair of `u64` halves.
    fn structural_hash(&self) -> (u64, u64) {
        let digest = self.inner.structural_hash();
        let high = (digest >> 64) as u64;
        // Truncating cast on purpose: it keeps exactly the low 64 bits.
        let low = digest as u64;
        (high, low)
    }

    /// Short summary string listing depth, size and variable count.
    fn __repr__(&self) -> String {
        let tree = &self.inner;
        let (depth, size, num_vars) = (tree.depth(), tree.size(), tree.num_vars());
        format!(
            "EmlTree(depth={}, size={}, num_vars={})",
            depth, size, num_vars
        )
    }
}
/// Python-facing `Canonical` class, exported under `scirs2.symbolic`.
///
/// Carries no data: it only serves as a namespace for the static constructor
/// methods that build `EmlTree` values through the Rust `Canonical` type.
#[pyclass(name = "Canonical", module = "scirs2.symbolic")]
pub struct PyCanonical;
#[pymethods]
impl PyCanonical {
#[staticmethod]
fn exp(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::exp(&x.inner),
}
}
#[staticmethod]
fn ln(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::ln(&x.inner),
}
}
#[staticmethod]
fn euler() -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::euler(),
}
}
#[staticmethod]
fn pi() -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::pi(),
}
}
#[staticmethod]
fn neg(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::neg(&x.inner),
}
}
#[staticmethod]
fn add(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::add(&a.inner, &b.inner),
}
}
#[staticmethod]
fn sub(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::sub(&a.inner, &b.inner),
}
}
#[staticmethod]
fn mul(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::mul(&a.inner, &b.inner),
}
}
#[staticmethod]
fn div(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::div(&a.inner, &b.inner),
}
}
#[staticmethod]
fn pow(a: &PyEmlTree, b: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::pow(&a.inner, &b.inner),
}
}
#[staticmethod]
fn sin(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::sin(&x.inner),
}
}
#[staticmethod]
fn cos(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::cos(&x.inner),
}
}
#[staticmethod]
fn tan(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::tan(&x.inner),
}
}
#[staticmethod]
fn arcsin(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arcsin(&x.inner),
}
}
#[staticmethod]
fn arccos(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arccos(&x.inner),
}
}
#[staticmethod]
fn arctan(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arctan(&x.inner),
}
}
#[staticmethod]
fn sinh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::sinh(&x.inner),
}
}
#[staticmethod]
fn cosh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::cosh(&x.inner),
}
}
#[staticmethod]
fn tanh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::tanh(&x.inner),
}
}
#[staticmethod]
fn arcsinh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arcsinh(&x.inner),
}
}
#[staticmethod]
fn arccosh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arccosh(&x.inner),
}
}
#[staticmethod]
fn arctanh(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::arctanh(&x.inner),
}
}
#[staticmethod]
fn sqrt(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::sqrt(&x.inner),
}
}
#[staticmethod]
fn abs(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::abs(&x.inner),
}
}
#[staticmethod]
fn square(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::square(&x.inner),
}
}
#[staticmethod]
fn reciprocal(x: &PyEmlTree) -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::reciprocal(&x.inner),
}
}
#[staticmethod]
fn nat(n: u64) -> PyResult<PyEmlTree> {
RustCanonical::nat(n)
.map(|t| PyEmlTree { inner: t })
.map_err(|e| PyValueError::new_err(e.to_string()))
}
#[staticmethod]
fn zero() -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::zero(),
}
}
#[staticmethod]
fn neg_one() -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::neg_one(),
}
}
#[staticmethod]
fn imag_unit() -> PyEmlTree {
PyEmlTree {
inner: RustCanonical::imag_unit(),
}
}
}
/// Python-facing `LoweredOp` class, exported under `scirs2.symbolic`.
///
/// Wraps one Rust `LoweredOp`. No constructor is exposed in this file; values
/// come from `lower`, `simplify`, `grad` and the results of `discover`.
// NOTE(review): `skip_from_py_object` presumably opts this `Clone` class out of
// the automatic by-value extraction impl — confirm against the PyO3 version in use.
#[pyclass(name = "LoweredOp", module = "scirs2.symbolic", skip_from_py_object)]
#[derive(Clone)]
pub struct PyLoweredOp {
// Wrapped Rust operator. Private: Python code reaches it only through the methods.
inner: RustLoweredOp,
}
#[pymethods]
impl PyLoweredOp {
    /// Variable count of the wrapped operator, as reported by `LoweredOp::count_vars`.
    fn count_vars(&self) -> usize {
        let op = &self.inner;
        op.count_vars()
    }

    /// Structural hash of the operator, split into `(high 64 bits, low 64 bits)`.
    ///
    /// The Rust hash is wider than 64 bits, so it is handed to Python as a
    /// pair of `u64` halves.
    fn structural_hash(&self) -> (u64, u64) {
        let digest = self.inner.structural_hash();
        let high = (digest >> 64) as u64;
        // Truncating cast on purpose: it keeps exactly the low 64 bits.
        let low = digest as u64;
        (high, low)
    }

    /// Short summary string listing the variable count.
    fn __repr__(&self) -> String {
        let count = self.inner.count_vars();
        format!("LoweredOp(count_vars={})", count)
    }
}
/// Lower an `EmlTree` into a `LoweredOp` using the Rust `lower` routine.
///
/// The input tree is only borrowed; a new `LoweredOp` is returned.
#[pyfunction]
fn lower(tree: &PyEmlTree) -> PyLoweredOp {
    let inner = rust_lower(&tree.inner);
    PyLoweredOp { inner }
}
/// Run the Rust `simplify_op` routine on `op` and return the result.
///
/// The input operator is only borrowed; a new `LoweredOp` is returned.
#[pyfunction]
fn simplify(op: &PyLoweredOp) -> PyLoweredOp {
    let inner = rust_simplify_op(&op.inner);
    PyLoweredOp { inner }
}
/// Run the Rust `grad` routine on `op` and return the resulting operator.
///
/// `wrt` is passed through unchanged — presumably the index of the variable
/// to differentiate with respect to; confirm against the Rust `grad` docs.
#[pyfunction]
fn grad(op: &PyLoweredOp, wrt: usize) -> PyLoweredOp {
    let inner = rust_grad(&op.inner, wrt);
    PyLoweredOp { inner }
}
/// Evaluate `op` to a float with the Rust `eval_real` routine.
///
/// `vars` supplies the values used to build the evaluation context
/// (presumably one value per variable, ordered by variable index — confirm
/// against `EvalCtx`).
///
/// Raises `RuntimeError` carrying the Rust error's message when evaluation fails.
#[pyfunction]
fn eval_real(op: &PyLoweredOp, vars: Vec<f64>) -> PyResult<f64> {
    let ctx = EvalCtx::new(&vars);
    match rust_eval_real(&op.inner, &ctx) {
        Ok(value) => Ok(value),
        Err(err) => Err(PyRuntimeError::new_err(err.to_string())),
    }
}
/// Search for formulas relating `features` to `targets` using the Rust
/// symbolic-regression `discover` routine.
///
/// Arguments:
/// - `features`: 2-D float64 array.
/// - `targets`: 1-D float64 array.
/// - `max_iter` (default 50), `top_n` (default 3), `beam_width` (default 32),
///   `max_depth` (default 6), `max_nodes` (default 20): forwarded to the
///   matching `SrConfig` builder setters.
///
/// Returns one `DiscoveredFormula` per result produced by the search, in the
/// order the search returned them.
// NOTE(review): no shape check is done here; which axis of `features` must
// match `targets` is decided by the Rust routine — confirm and document.
#[pyfunction]
#[pyo3(signature = (
    features,
    targets,
    max_iter = 50,
    top_n = 3,
    beam_width = 32,
    max_depth = 6,
    max_nodes = 20,
))]
// The argument list mirrors the Python keyword signature one-to-one.
#[allow(clippy::too_many_arguments)]
fn discover(
    py: Python<'_>,
    features: PyReadonlyArray2<f64>,
    targets: PyReadonlyArray1<f64>,
    max_iter: usize,
    top_n: usize,
    beam_width: usize,
    max_depth: usize,
    max_nodes: usize,
) -> PyResult<Vec<PyDiscoveredFormula>> {
    let search_config = RustSrConfig::default()
        .with_max_iter(max_iter)
        .with_top_n(top_n)
        .with_beam_width(beam_width)
        .with_max_depth(max_depth)
        .with_max_nodes(max_nodes);

    let feature_view = features.as_array();
    let target_view = targets.as_array();

    // The search itself runs inside `py.detach`, i.e. while this thread is
    // detached from the Python interpreter.
    let found = py.detach(|| rust_discover(feature_view, target_view, &search_config));

    // Copy each Rust result into its Python-facing counterpart.
    let mut formulas = Vec::new();
    for candidate in found {
        let op = PyLoweredOp {
            inner: candidate.op,
        };
        formulas.push(PyDiscoveredFormula {
            op,
            mse: candidate.fitness.mse,
            r_squared: candidate.fitness.r_squared,
            combined: candidate.fitness.combined,
            node_count: candidate.node_count,
            n_vars: candidate.n_vars,
        });
    }
    Ok(formulas)
}
/// Python-facing `DiscoveredFormula` class: one result of `discover()`.
///
/// Every field is exposed to Python as a read-only attribute through
/// `#[pyo3(get)]`.
// NOTE(review): `skip_from_py_object` presumably opts this `Clone` class out of
// the automatic by-value extraction impl — confirm against the PyO3 version in use.
#[pyclass(
name = "DiscoveredFormula",
module = "scirs2.symbolic",
skip_from_py_object
)]
#[derive(Clone)]
pub struct PyDiscoveredFormula {
/// The discovered formula as a `LoweredOp`.
#[pyo3(get)]
pub op: PyLoweredOp,
/// The `mse` value of the result's fitness record (presumably mean squared error).
#[pyo3(get)]
pub mse: f64,
/// The `r_squared` value of the result's fitness record.
#[pyo3(get)]
pub r_squared: f64,
/// The `combined` value of the result's fitness record.
#[pyo3(get)]
pub combined: f64,
/// Node count reported by the search for this formula.
#[pyo3(get)]
pub node_count: usize,
/// Variable count reported by the search for this formula.
#[pyo3(get)]
pub n_vars: usize,
}
#[pymethods]
impl PyDiscoveredFormula {
    /// Short summary string: `mse` and `r_squared` with six decimals, followed
    /// by the node count (printed as `n_nodes`) and the variable count.
    fn __repr__(&self) -> String {
        let Self {
            mse,
            r_squared,
            node_count,
            n_vars,
            ..
        } = self;
        format!(
            "DiscoveredFormula(mse={:.6}, r_squared={:.6}, n_nodes={}, n_vars={})",
            mse, r_squared, node_count, n_vars
        )
    }
}
/// Create the `symbolic` submodule, fill it with the classes and functions
/// defined in this file, set its `__doc__`, and attach it to the parent `m`.
///
/// # Errors
///
/// Returns the first Python error raised while creating the submodule,
/// registering an item, or attaching the submodule to `m`.
// NOTE(review): `add_submodule` alone does not insert the module into
// `sys.modules`, so `import <parent>.symbolic` may not work unless that is
// handled elsewhere — verify against the parent module's setup.
pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Text installed as the submodule's `__doc__`.
    const SYMBOLIC_DOC: &str = "Symbolic mathematics — EML substrate, evaluation, gradient, and \
        beam-search symbolic regression.\n\nClasses:\n - EmlTree: uniform \
        binary EML tree (constant 1 + var leaves + binary eml nodes).\n - \
        Canonical: namespace of elementary-function constructors.\n - \
        LoweredOp: flat operator IR produced by lower(tree).\n - \
        DiscoveredFormula: result of discover().\n\nFunctions:\n - \
        lower(tree) -> LoweredOp\n - simplify(op) -> LoweredOp\n - grad(op, wrt) \
        -> LoweredOp\n - eval_real(op, vars) -> float\n - discover(features, \
        targets, ...) -> list[DiscoveredFormula]";

    let submodule = PyModule::new(m.py(), "symbolic")?;

    // Classes.
    submodule.add_class::<PyEmlTree>()?;
    submodule.add_class::<PyCanonical>()?;
    submodule.add_class::<PyLoweredOp>()?;
    submodule.add_class::<PyDiscoveredFormula>()?;

    // Free functions.
    submodule.add_function(wrap_pyfunction!(lower, &submodule)?)?;
    submodule.add_function(wrap_pyfunction!(simplify, &submodule)?)?;
    submodule.add_function(wrap_pyfunction!(grad, &submodule)?)?;
    submodule.add_function(wrap_pyfunction!(eval_real, &submodule)?)?;
    submodule.add_function(wrap_pyfunction!(discover, &submodule)?)?;

    submodule.add("__doc__", SYMBOLIC_DOC)?;

    m.add_submodule(&submodule)
}