Skip to main content

symbolic_regression/
symbolic_regression.rs

1//! Symbolic Regression via genetic programming over expression graphs.
2//!
3//! Given noisy data from `f(x) = x² · sin(x)`, evolves a population of
4//! expression trees to rediscover the formula.
5//!
6//! ```sh
7//! cargo run --example symbolic_regression -p tang-expr
8//! ```
9
10use tang_expr::{ExprGraph, ExprId};
11
12// --- Inline LCG PRNG (matches `collocation_random` pattern) -----------------
13
/// Small deterministic pseudo-random generator: a 64-bit linear congruential
/// generator whose whole state is the wrapped `u64`.
struct Lcg(u64);

impl Lcg {
    /// Creates a generator whose initial state is exactly `seed`.
    fn new(seed: u64) -> Self {
        Lcg(seed)
    }

    /// Advances the state by one multiply-add step (wrapping on overflow) and
    /// returns the new state.
    fn next(&mut self) -> u64 {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        let advanced = self.0.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        self.0 = advanced;
        advanced
    }

    /// Draws a float in `[0, 1)` built from the upper 53 bits of the next state.
    fn uniform(&mut self) -> f64 {
        let top_bits = self.next() >> 11;
        let scale = (1u64 << 53) as f64;
        top_bits as f64 / scale
    }

    /// Draws an index in `0..n`.
    ///
    /// # Panics
    ///
    /// Panics when `n == 0` (remainder by zero).
    fn range(&mut self, n: usize) -> usize {
        let scaled = self.uniform() * n as f64;
        // The final remainder keeps the result below `n` even if the scaled
        // float were to round up to `n` itself.
        (scaled as usize) % n
    }
}
34
35// --- AST representation for genetic programming -----------------------------
36
/// Expression tree evolved by the genetic program.
///
/// `PartialEq` is derived so trees can be compared structurally (handy for
/// assertions and duplicate detection); `Eq` is not available because
/// literals are `f64`.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    /// The single input variable.
    X,
    /// A floating-point constant.
    Lit(f64),
    /// Sum of two sub-expressions.
    Add(Box<Expr>, Box<Expr>),
    /// Product of two sub-expressions.
    Mul(Box<Expr>, Box<Expr>),
    /// Sine of a sub-expression.
    Sin(Box<Expr>),
    /// Negation of a sub-expression.
    Neg(Box<Expr>),
}
46
47impl Expr {
48    /// Count nodes in this tree.
49    fn size(&self) -> usize {
50        match self {
51            Expr::X | Expr::Lit(_) => 1,
52            Expr::Sin(a) | Expr::Neg(a) => 1 + a.size(),
53            Expr::Add(a, b) | Expr::Mul(a, b) => 1 + a.size() + b.size(),
54        }
55    }
56
57    /// Convert to ExprGraph node, returning the root ExprId.
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
82
83    /// Format as string using ExprGraph's fmt_expr.
84    fn format(&self) -> String {
85        let mut g = ExprGraph::new();
86        let root = self.to_expr(&mut g);
87        g.fmt_expr(root)
88    }
89
90    /// Evaluate at a point using ExprGraph's compiled eval.
91    fn eval_at(&self, x: f64) -> f64 {
92        let mut g = ExprGraph::new();
93        let root = self.to_expr(&mut g);
94        g.eval(root, &[x])
95    }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101    if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102        return match rng.range(3) {
103            0 => Expr::X,
104            1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105            _ => Expr::X,
106        };
107    }
108
109    match rng.range(5) {
110        0 => Expr::Add(
111            Box::new(random_expr(depth - 1, rng)),
112            Box::new(random_expr(depth - 1, rng)),
113        ),
114        1 | 2 => Expr::Mul(
115            Box::new(random_expr(depth - 1, rng)),
116            Box::new(random_expr(depth - 1, rng)),
117        ),
118        3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119        _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120    }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127    let size = expr.size();
128    let target = rng.range(size);
129    grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133    if *counter == target {
134        *counter += expr.size(); // skip subtree
135        return random_expr(3, rng);
136    }
137    *counter += 1;
138    match expr {
139        Expr::X => Expr::X,
140        Expr::Lit(v) => Expr::Lit(*v),
141        Expr::Add(a, b) => Expr::Add(
142            Box::new(grow_at(a, target, counter, rng)),
143            Box::new(grow_at(b, target, counter, rng)),
144        ),
145        Expr::Mul(a, b) => Expr::Mul(
146            Box::new(grow_at(a, target, counter, rng)),
147            Box::new(grow_at(b, target, counter, rng)),
148        ),
149        Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150        Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151    }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156    let size = expr.size();
157    let target = rng.range(size);
158    point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162    if *counter == target {
163        *counter += 1;
164        return match expr {
165            Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166            Expr::Lit(_) => Expr::X,
167            Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168            Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169            Expr::Sin(a) => Expr::Neg(a.clone()),
170            Expr::Neg(a) => Expr::Sin(a.clone()),
171        };
172    }
173    *counter += 1;
174    match expr {
175        Expr::X => Expr::X,
176        Expr::Lit(v) => Expr::Lit(*v),
177        Expr::Add(a, b) => Expr::Add(
178            Box::new(point_at(a, target, counter, rng)),
179            Box::new(point_at(b, target, counter, rng)),
180        ),
181        Expr::Mul(a, b) => Expr::Mul(
182            Box::new(point_at(a, target, counter, rng)),
183            Box::new(point_at(b, target, counter, rng)),
184        ),
185        Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186        Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187    }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to identity if conversion back would be complex.
192fn mutate_simplify(expr: &Expr) -> Expr {
193    let mut g = ExprGraph::new();
194    let root = expr.to_expr(&mut g);
195    let simplified = g.simplify(root);
196    let s = g.fmt_expr(simplified);
197
198    // Quick check: if simplified form is much shorter, it's better
199    let orig = expr.format();
200    if s.len() < orig.len() {
201        // Re-evaluate to confirm it still works
202        let test = g.eval::<f64>(simplified, &[1.0]);
203        let orig_test = expr.eval_at(1.0);
204        if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205            // Return a new tree built from the simplified graph
206            return expr_from_graph(&g, simplified);
207        }
208    }
209    expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations we don't represent in our AST, just evaluate as literal
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
/// The function the search tries to rediscover: `x² · sin(x)`.
fn target(x: f64) -> f64 {
    let squared = x * x;
    squared * x.sin()
}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247    (0..n)
248        .map(|i| {
249            let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250            let noise = (rng.uniform() - 0.5) * 0.01;
251            (x, target(x) + noise)
252        })
253        .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258    let mut g = ExprGraph::new();
259    let root = expr.to_expr(&mut g);
260    let compiled = g.compile(root);
261
262    let mut mse = 0.0;
263    for &(x, y) in data {
264        let pred = compiled(&[x]);
265        if pred.is_nan() || pred.is_infinite() {
266            return f64::INFINITY;
267        }
268        let err = pred - y;
269        mse += err * err;
270    }
271    mse /= data.len() as f64;
272
273    let penalty = 0.001 * expr.size() as f64;
274    mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280    let mut best_idx = rng.range(pop.len());
281    for _ in 1..k {
282        let idx = rng.range(pop.len());
283        if pop[idx].1 < pop[best_idx].1 {
284            best_idx = idx;
285        }
286    }
287    &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}