use crate::compression::mdl_score;
use crate::equation::Expression;
use crate::error::{LmmError, Result};
use crate::tensor::Tensor;
use crate::traits::Discoverable;
use rand::{Rng, RngExt};
/// Configuration for an evolutionary search over [`Expression`] trees.
///
/// The search itself is run by `fit`; this type only carries the tunable
/// parameters. `Debug` and `Clone` are derived so a configuration can be
/// logged and reused across several runs.
#[derive(Debug, Clone)]
pub struct SymbolicRegression {
    /// Depth at which random tree generation stops branching and emits a leaf.
    pub max_depth: usize,
    /// Number of individuals each generation is filled up to.
    pub population_size: usize,
    /// Number of generations the search loop runs for.
    pub iterations: usize,
    /// Names used for the variable leaves of generated expressions.
    pub variable_names: Vec<String>,
}
impl SymbolicRegression {
/// Creates a configuration with the given depth limit and generation count.
///
/// The population size starts at 50 and the only variable is named `x`;
/// use `with_population` / `with_variables` to override either.
pub fn new(max_depth: usize, iterations: usize) -> Self {
    let variable_names = vec![String::from("x")];
    Self {
        max_depth,
        iterations,
        population_size: 50,
        variable_names,
    }
}
pub fn with_variables(mut self, vars: Vec<String>) -> Self {
self.variable_names = vars;
self
}
pub fn with_population(mut self, size: usize) -> Self {
self.population_size = size;
self
}
/// Builds the hand-written starting expressions for every configured variable.
///
/// For each variable `x`, three coefficients are drawn (`a` in `[0.5, 5]`,
/// `b` in `[-10, 10]`, `c` in `[-5, 5]`) and four templates are emitted in
/// this order: `a*x + b`, `a*x^2 + b*x + c`, `a*sin(0.1*x) + b`, and `x`.
fn seed_templates<R: Rng>(&self, rng: &mut R) -> Vec<Expression> {
    // Tiny constructors so the template shapes below read like the formulas.
    fn konst(value: f64) -> Box<Expression> {
        Box::new(Expression::Constant(value))
    }
    fn scaled(factor: f64, body: Expression) -> Expression {
        Expression::Mul(konst(factor), Box::new(body))
    }
    fn offset(body: Expression, shift: f64) -> Expression {
        Expression::Add(Box::new(body), konst(shift))
    }

    let mut seeds = Vec::new();
    for name in &self.variable_names {
        // Draw order (a, b, c) is part of the behaviour: it fixes the RNG stream.
        let slope: f64 = rng.random_range(0.5..=5.0);
        let intercept: f64 = rng.random_range(-10.0..=10.0);
        let constant: f64 = rng.random_range(-5.0..=5.0);
        let var = || Expression::Variable(name.clone());

        // a*x + b
        let linear = offset(scaled(slope, var()), intercept);

        // (a*x^2 + b*x) + c
        let square = Expression::Pow(Box::new(var()), konst(2.0));
        let quadratic = offset(
            Expression::Add(
                Box::new(scaled(slope, square)),
                Box::new(scaled(intercept, var())),
            ),
            constant,
        );

        // a*sin(0.1*x) + b
        let wave = Expression::Sin(Box::new(scaled(0.1, var())));
        let sinusoid = offset(scaled(slope, wave), intercept);

        seeds.extend([linear, quadratic, sinusoid, var()]);
    }
    seeds
}
/// Returns `true` when at least one `Variable` leaf occurs anywhere in `expr`.
fn has_variables(expr: &Expression) -> bool {
    match expr {
        // Leaves answer directly.
        Expression::Constant(_) => false,
        Expression::Variable(_) => true,
        // Binary nodes: the left operand is inspected first, the right one
        // only if the left holds no variable.
        Expression::Add(lhs, rhs)
        | Expression::Sub(lhs, rhs)
        | Expression::Mul(lhs, rhs)
        | Expression::Div(lhs, rhs)
        | Expression::Pow(lhs, rhs) => [lhs, rhs]
            .into_iter()
            .any(|side| Self::has_variables(side)),
        // Unary nodes simply defer to their operand.
        Expression::Neg(inner)
        | Expression::Abs(inner)
        | Expression::Sin(inner)
        | Expression::Cos(inner)
        | Expression::Exp(inner)
        | Expression::Log(inner) => Self::has_variables(inner),
    }
}
/// Generates a random expression tree rooted at the given `depth`.
///
/// A leaf is produced once `depth` reaches `max_depth`, and also with
/// probability 0.4 at any depth below the root. Otherwise one of eight
/// operators is chosen uniformly: `Add`, `Sub`, `Mul`, `Div`, `Sin`, `Cos`,
/// `Exp`, or `Pow` with a constant exponent of 2 or 3.
fn random_expr<R: Rng>(&self, rng: &mut R, depth: usize) -> Expression {
    let at_limit = depth >= self.max_depth;
    // The coin is only flipped below the root and below the depth limit.
    if at_limit || (depth > 0 && rng.random_bool(0.4)) {
        return self.random_leaf(rng);
    }

    let child_depth = depth + 1;
    let choice: u8 = rng.random_range(0..8);
    match choice {
        0..=3 => {
            let combine: fn(Box<Expression>, Box<Expression>) -> Expression = match choice {
                0 => Expression::Add,
                1 => Expression::Sub,
                2 => Expression::Mul,
                _ => Expression::Div,
            };
            // Left operand is generated before the right one.
            let lhs = self.random_expr(rng, child_depth);
            let rhs = self.random_expr(rng, child_depth);
            combine(Box::new(lhs), Box::new(rhs))
        }
        4..=6 => {
            let wrap: fn(Box<Expression>) -> Expression = match choice {
                4 => Expression::Sin,
                5 => Expression::Cos,
                _ => Expression::Exp,
            };
            let operand = self.random_expr(rng, child_depth);
            wrap(Box::new(operand))
        }
        _ => {
            // Base first, then the small integer exponent.
            let base = self.random_expr(rng, child_depth);
            let exponent = f64::from(rng.random_range(2u32..=3));
            Expression::Pow(Box::new(base), Box::new(Expression::Constant(exponent)))
        }
    }
}
/// Produces a random leaf node.
///
/// When variable names are configured, a uniformly chosen variable is
/// returned with probability 0.6. In every other case the leaf is a constant
/// drawn from `[-5, 5]` and rounded to one decimal place.
fn random_leaf<R: Rng>(&self, rng: &mut R) -> Expression {
    let names = &self.variable_names;
    // No coin flip happens when there are no variables to choose from.
    if names.is_empty() || !rng.random_bool(0.6) {
        let raw: f64 = rng.random_range(-5.0..=5.0);
        let rounded = (raw * 10.0).round() / 10.0;
        return Expression::Constant(rounded);
    }
    let pick = rng.random_range(0..names.len());
    Expression::Variable(names[pick].clone())
}
/// Returns a mutated copy of `expr`.
///
/// At every visited node there is a 0.3 chance that the whole subtree is
/// thrown away and replaced by a fresh random tree. Otherwise:
/// constants are shifted by at most 0.5 in either direction and clamped to
/// `[-100, 100]`, variables are swapped for a new random leaf, and operator
/// nodes are rebuilt from their mutated operands. The exponent of a `Pow`
/// node is copied unchanged.
fn mutate<R: Rng>(&self, expr: &Expression, rng: &mut R) -> Expression {
    let regenerate = rng.random_bool(0.3);
    if regenerate {
        return self.random_expr(rng, 0);
    }

    match expr {
        // Leaves.
        Expression::Variable(_) => self.random_leaf(rng),
        Expression::Constant(value) => {
            let jitter: f64 = rng.random_range(-1.0..=1.0);
            let nudged = value + jitter * 0.5;
            Expression::Constant(nudged.clamp(-100.0, 100.0))
        }

        // Unary operators.
        Expression::Neg(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Neg(Box::new(child))
        }
        Expression::Abs(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Abs(Box::new(child))
        }
        Expression::Sin(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Sin(Box::new(child))
        }
        Expression::Cos(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Cos(Box::new(child))
        }
        Expression::Exp(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Exp(Box::new(child))
        }
        Expression::Log(inner) => {
            let child = self.mutate(inner, rng);
            Expression::Log(Box::new(child))
        }

        // Only the base of a power is mutated.
        Expression::Pow(base, exponent) => {
            let new_base = self.mutate(base, rng);
            Expression::Pow(Box::new(new_base), exponent.clone())
        }

        // Binary operators: left operand is mutated before the right one.
        Expression::Add(lhs, rhs) => {
            let left = self.mutate(lhs, rng);
            let right = self.mutate(rhs, rng);
            Expression::Add(Box::new(left), Box::new(right))
        }
        Expression::Sub(lhs, rhs) => {
            let left = self.mutate(lhs, rng);
            let right = self.mutate(rhs, rng);
            Expression::Sub(Box::new(left), Box::new(right))
        }
        Expression::Mul(lhs, rhs) => {
            let left = self.mutate(lhs, rng);
            let right = self.mutate(rhs, rng);
            Expression::Mul(Box::new(left), Box::new(right))
        }
        Expression::Div(lhs, rhs) => {
            let left = self.mutate(lhs, rng);
            let right = self.mutate(rhs, rng);
            Expression::Div(Box::new(left), Box::new(right))
        }
    }
}
/// Combines two parents into one child.
///
/// A fair coin decides which parent acts as the base tree and which one is
/// the donor; the actual grafting is done by `swap_subtree`.
fn crossover<R: Rng>(
    &self,
    parent_a: &Expression,
    parent_b: &Expression,
    rng: &mut R,
) -> Expression {
    let (base, donor) = if rng.random_bool(0.5) {
        (parent_a, parent_b)
    } else {
        (parent_b, parent_a)
    };
    self.swap_subtree(base, donor, rng)
}
/// Walks `base` and randomly grafts copies of `donor` into it.
///
/// At every visited node there is a 0.3 chance that the node is replaced by
/// a clone of the entire `donor` tree. Nodes that survive are rebuilt from
/// their recursively processed operands; leaves are copied as they are and
/// the exponent of a `Pow` node is never descended into.
fn swap_subtree<R: Rng>(
    &self,
    base: &Expression,
    donor: &Expression,
    rng: &mut R,
) -> Expression {
    let graft_here = rng.random_bool(0.3);
    if graft_here {
        return donor.clone();
    }

    match base {
        // Leaves have nothing to recurse into.
        Expression::Constant(_) | Expression::Variable(_) => base.clone(),

        // Only the base of a power takes part in the exchange.
        Expression::Pow(radix, exponent) => {
            let new_radix = self.swap_subtree(radix, donor, rng);
            Expression::Pow(Box::new(new_radix), exponent.clone())
        }

        // Unary operators.
        Expression::Sin(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Sin(Box::new(child))
        }
        Expression::Cos(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Cos(Box::new(child))
        }
        Expression::Exp(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Exp(Box::new(child))
        }
        Expression::Log(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Log(Box::new(child))
        }
        Expression::Neg(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Neg(Box::new(child))
        }
        Expression::Abs(inner) => {
            let child = self.swap_subtree(inner, donor, rng);
            Expression::Abs(Box::new(child))
        }

        // Binary operators: left operand is processed before the right one.
        Expression::Add(lhs, rhs) => {
            let left = self.swap_subtree(lhs, donor, rng);
            let right = self.swap_subtree(rhs, donor, rng);
            Expression::Add(Box::new(left), Box::new(right))
        }
        Expression::Sub(lhs, rhs) => {
            let left = self.swap_subtree(lhs, donor, rng);
            let right = self.swap_subtree(rhs, donor, rng);
            Expression::Sub(Box::new(left), Box::new(right))
        }
        Expression::Mul(lhs, rhs) => {
            let left = self.swap_subtree(lhs, donor, rng);
            let right = self.swap_subtree(rhs, donor, rng);
            Expression::Mul(Box::new(left), Box::new(right))
        }
        Expression::Div(lhs, rhs) => {
            let left = self.swap_subtree(lhs, donor, rng);
            let right = self.swap_subtree(rhs, donor, rng);
            Expression::Div(Box::new(left), Box::new(right))
        }
    }
}
/// Picks one individual by tournament selection.
///
/// One contestant is drawn uniformly at random, then `k - 1` challengers are
/// drawn (with replacement); a challenger takes over only when its fitness
/// is strictly lower. With `k <= 1` the first draw is returned as is.
///
/// # Panics
///
/// Panics if `population` is empty (an index cannot be sampled from an empty
/// range) or if `fitnesses` has no entry for a sampled index.
fn tournament_select<'a, R: Rng>(
    population: &'a [Expression],
    fitnesses: &[f64],
    rng: &mut R,
    k: usize,
) -> &'a Expression {
    let size = population.len();
    let first = rng.random_range(0..size);
    let winner = (1..k).fold(first, |champion, _| {
        let challenger = rng.random_range(0..size);
        if fitnesses[challenger] < fitnesses[champion] {
            challenger
        } else {
            champion
        }
    });
    &population[winner]
}
/// Runs the evolutionary search and returns the best expression found,
/// simplified.
///
/// The population starts from `seed_templates` and is padded with random
/// trees. Each generation is scored with `mdl_score` (lower is better; NaN
/// or infinite scores are replaced by a fixed penalty). The best individual
/// seen so far is carried into the next generation unchanged, and the rest
/// of that generation is bred by tournament selection followed by crossover,
/// mutation, or plain copying. When variable names are configured, only
/// expressions that contain a variable can become the tracked best, and
/// variable-free children are usually replaced by a fresh random tree.
///
/// After the last breeding step the resulting generation is scored once more
/// so that it can still contribute to the result.
///
/// NOTE(review): `inputs` and `targets` are handed to `mdl_score` as-is; a
/// length mismatch between them is not checked here — confirm how
/// `mdl_score` treats it.
///
/// # Errors
///
/// Returns `LmmError::Discovery` when `inputs` or `targets` is empty.
pub fn fit(&self, inputs: &[Vec<f64>], targets: &[f64]) -> Result<Expression> {
    // Fitness given to individuals whose MDL score is NaN or infinite.
    const NON_FINITE_PENALTY: f64 = 1e9;
    // Number of contestants drawn for each tournament.
    const TOURNAMENT_SIZE: usize = 5;

    if inputs.is_empty() || targets.is_empty() {
        return Err(LmmError::Discovery("Empty training data".into()));
    }

    let mut rng = rand::rng();
    let mut population: Vec<Expression> = self.seed_templates(&mut rng);
    // Keep at least one individual even when no variables are configured and
    // `population_size` is 0; otherwise the indexing below would panic.
    let initial_size = self.population_size.max(1);
    while population.len() < initial_size {
        population.push(self.random_expr(&mut rng, 0));
    }

    let has_vars = !self.variable_names.is_empty();
    // Prefer a starting candidate that actually depends on a variable.
    let mut best_expr = population
        .iter()
        .find(|e| !has_vars || Self::has_variables(e))
        .unwrap_or(&population[0])
        .clone();
    let mut best_score = {
        let s = mdl_score(&best_expr, inputs, targets);
        if s.is_finite() { s } else { f64::MAX }
    };

    // `iterations` breeding rounds plus one final scoring-only pass, so the
    // generation produced by the last round is not thrown away unevaluated.
    for generation in 0..=self.iterations {
        let fitnesses: Vec<f64> = population
            .iter()
            .map(|e| {
                let score = mdl_score(e, inputs, targets);
                if score.is_finite() {
                    score
                } else {
                    NON_FINITE_PENALTY
                }
            })
            .collect();

        for (candidate, &score) in population.iter().zip(&fitnesses) {
            if score < best_score && (!has_vars || Self::has_variables(candidate)) {
                best_score = score;
                best_expr = candidate.clone();
            }
        }

        if generation == self.iterations {
            break;
        }

        // Elitism: the best individual so far always survives.
        let mut new_pop = vec![best_expr.clone()];
        while new_pop.len() < self.population_size {
            let parent_a =
                Self::tournament_select(&population, &fitnesses, &mut rng, TOURNAMENT_SIZE);
            let op: u8 = rng.random_range(0..3);
            let child = match op {
                0 => {
                    let parent_b = Self::tournament_select(
                        &population,
                        &fitnesses,
                        &mut rng,
                        TOURNAMENT_SIZE,
                    );
                    self.crossover(parent_a, parent_b, &mut rng)
                }
                1 => self.mutate(parent_a, &mut rng),
                _ => parent_a.clone(),
            };
            let simplified = child.simplify();
            // Children that collapsed to a constant are mostly replaced by a
            // fresh random tree to keep variable-bearing candidates around.
            if has_vars && !Self::has_variables(&simplified) && rng.random_bool(0.7) {
                new_pop.push(self.random_expr(&mut rng, 0));
            } else {
                new_pop.push(simplified);
            }
        }
        population = new_pop;
    }

    Ok(best_expr.simplify())
}
}
impl Discoverable for SymbolicRegression {
    /// Discovers an expression for `targets` from tensor inputs.
    ///
    /// With no tensors at all the bare variable `x` is returned. Otherwise
    /// each tensor's `data` is copied into one input row and a search with a
    /// depth limit of 3 and 50 iterations is run via `fit`.
    fn discover(data: &[Tensor], targets: &[f64]) -> Result<Expression> {
        if data.is_empty() {
            Ok(Expression::Variable(String::from("x")))
        } else {
            let rows = data
                .iter()
                .map(|tensor| tensor.data.clone())
                .collect::<Vec<Vec<f64>>>();
            let engine = SymbolicRegression::new(3, 50);
            engine.fit(&rows, targets)
        }
    }
}