use wasm_bindgen::prelude::*;
/// JS-facing configuration for symbolic regression.
///
/// Wraps a `crate::symreg::SymRegConfig` and carries one extra setting,
/// `max_formulas`, which is not part of the wrapped config.
#[wasm_bindgen]
pub struct WasmSymRegConfig {
    /// Upper bound on how many discovered formulas are handed back to JS.
    max_formulas: usize,
    /// The underlying search configuration.
    inner: crate::symreg::SymRegConfig,
}
#[wasm_bindgen]
impl WasmSymRegConfig {
    /// Builds a configuration from `SymRegConfig::default()`, capped at 10 formulas.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        let inner = crate::symreg::SymRegConfig::default();
        Self {
            max_formulas: 10,
            inner,
        }
    }

    /// Builds a configuration from the `SymRegConfig::quick()` preset, capped at 10 formulas.
    pub fn quick() -> Self {
        let inner = crate::symreg::SymRegConfig::quick();
        Self {
            max_formulas: 10,
            inner,
        }
    }

    /// Builds a configuration from the `SymRegConfig::balanced()` preset, capped at 10 formulas.
    pub fn balanced() -> Self {
        let inner = crate::symreg::SymRegConfig::balanced();
        Self {
            max_formulas: 10,
            inner,
        }
    }

    /// Reads `max_depth` from the wrapped config.
    #[wasm_bindgen(getter)]
    pub fn max_depth(&self) -> usize {
        let config = &self.inner;
        config.max_depth
    }

    /// Writes `max_depth` into the wrapped config.
    #[wasm_bindgen(setter)]
    pub fn set_max_depth(&mut self, v: usize) {
        let config = &mut self.inner;
        config.max_depth = v;
    }

    /// Maximum number of formulas returned to JS.
    #[wasm_bindgen(getter)]
    pub fn max_formulas(&self) -> usize {
        self.max_formulas
    }

    /// Sets the maximum number of formulas returned to JS.
    #[wasm_bindgen(setter)]
    pub fn set_max_formulas(&mut self, v: usize) {
        self.max_formulas = v;
    }

    /// Reads `max_iter` from the wrapped config.
    #[wasm_bindgen(getter)]
    pub fn max_iter(&self) -> usize {
        let config = &self.inner;
        config.max_iter
    }

    /// Writes `max_iter` into the wrapped config.
    #[wasm_bindgen(setter)]
    pub fn set_max_iter(&mut self, v: usize) {
        let config = &mut self.inner;
        config.max_iter = v;
    }

    /// Reads the optional `seed` from the wrapped config (`None` when unset).
    #[wasm_bindgen(getter)]
    pub fn seed(&self) -> Option<u64> {
        let config = &self.inner;
        config.seed
    }

    /// Writes the optional `seed` into the wrapped config; `None` clears it.
    #[wasm_bindgen(setter)]
    pub fn set_seed(&mut self, v: Option<u64>) {
        let config = &mut self.inner;
        config.seed = v;
    }
}
impl Default for WasmSymRegConfig {
fn default() -> Self {
Self::new()
}
}
/// JS-facing wrapper around a single `crate::symreg::DiscoveredFormula`.
#[wasm_bindgen]
pub struct WasmDiscoveredFormula {
    /// The wrapped result from the Rust search engine.
    inner: crate::symreg::DiscoveredFormula,
}
#[wasm_bindgen]
impl WasmDiscoveredFormula {
    /// Returns an owned copy of the formula's `pretty` text.
    #[wasm_bindgen(getter)]
    pub fn pretty(&self) -> String {
        self.inner.pretty.to_owned()
    }

    /// The `mse` value stored on the discovered formula.
    #[wasm_bindgen(getter)]
    pub fn mse(&self) -> f64 {
        let formula = &self.inner;
        formula.mse
    }

    /// The `complexity` value stored on the discovered formula.
    #[wasm_bindgen(getter)]
    pub fn complexity(&self) -> usize {
        let formula = &self.inner;
        formula.complexity
    }

    /// The `score` value stored on the discovered formula.
    #[wasm_bindgen(getter)]
    pub fn score(&self) -> f64 {
        let formula = &self.inner;
        formula.score
    }

    /// Lowers the formula's EML tree, simplifies it, and renders it as LaTeX.
    ///
    /// The lowering and simplification are redone on every call.
    pub fn to_latex(&self) -> String {
        let tree = &self.inner.eml_tree;
        tree.lower().simplify().to_latex()
    }

    /// Evaluates the lowered, simplified expression at the point `xs`.
    ///
    /// NOTE(review): behavior when `xs.len()` differs from the feature count
    /// is decided by the expression's `eval` — confirm before relying on it.
    pub fn eval(&self, xs: &[f64]) -> f64 {
        let tree = &self.inner.eml_tree;
        tree.lower().simplify().eval(xs)
    }
}
/// JS-facing symbolic-regression engine holding a snapshot of its settings.
#[wasm_bindgen]
pub struct WasmSymRegEngine {
    /// Cap applied to the number of formulas returned by `discover`.
    max_formulas: usize,
    /// Search configuration passed to a fresh engine on each `discover` call.
    config: crate::symreg::SymRegConfig,
}
#[wasm_bindgen]
impl WasmSymRegEngine {
    /// Creates an engine from a snapshot of `config`.
    ///
    /// The settings are copied, so later changes to `config` do not affect
    /// this engine.
    #[wasm_bindgen(constructor)]
    pub fn new(config: &WasmSymRegConfig) -> Self {
        Self {
            config: config.inner.clone(),
            max_formulas: config.max_formulas,
        }
    }

    /// Runs symbolic regression on a row-major feature matrix.
    ///
    /// `x_flat` holds `n_samples` consecutive rows of `n_features` values each.
    /// `y_flat` holds one target per row. A fresh
    /// `crate::symreg::SymRegEngine` is built from the stored config for every
    /// call. At most `max_formulas` results are returned, in the order the
    /// engine produced them.
    ///
    /// # Errors
    ///
    /// Returns a JS string error if `n_samples * n_features` overflows `usize`,
    /// if `x_flat.len()` or `y_flat.len()` does not match the given shape, or
    /// if the underlying engine reports an error.
    pub fn discover(
        &self,
        x_flat: &[f64],
        y_flat: &[f64],
        n_samples: usize,
        n_features: usize,
    ) -> Result<Vec<WasmDiscoveredFormula>, JsValue> {
        // On wasm32 `usize` is 32 bits, and these sizes come straight from JS.
        // An unchecked product could trap (debug) or wrap (release). A wrapped
        // value can pass the length check and then panic while splitting rows.
        let expected_x_len = n_samples.checked_mul(n_features).ok_or_else(|| {
            JsValue::from_str(&format!(
                "n_samples*n_features overflows usize: n_samples={} n_features={}",
                n_samples, n_features
            ))
        })?;
        if x_flat.len() != expected_x_len {
            return Err(JsValue::from_str(&format!(
                "x_flat.len()={} but n_samples*n_features={}",
                x_flat.len(),
                expected_x_len
            )));
        }
        if y_flat.len() != n_samples {
            return Err(JsValue::from_str(&format!(
                "y_flat.len()={} but n_samples={}",
                y_flat.len(),
                n_samples
            )));
        }

        // Split the row-major buffer into one Vec per sample. `chunks_exact`
        // panics on a zero chunk size, so zero-feature input is built directly
        // (n_samples empty rows, matching the original slicing behavior).
        let inputs: Vec<Vec<f64>> = if n_features == 0 {
            vec![Vec::new(); n_samples]
        } else {
            x_flat
                .chunks_exact(n_features)
                .map(<[f64]>::to_vec)
                .collect()
        };
        let targets: Vec<f64> = y_flat.to_vec();

        let engine = crate::symreg::SymRegEngine::new(self.config.clone());
        let formulas = engine
            .discover(&inputs, &targets, n_features)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(formulas
            .into_iter()
            .take(self.max_formulas)
            .map(|f| WasmDiscoveredFormula { inner: f })
            .collect())
    }
}