Skip to main content

phop_core/
codegen.rs

//! Code generation for discovered EML expressions (Rust, Python/NumPy, SymPy).
//!
//! LaTeX rendering is delegated to `oxieml` (`tree.lower().simplify().to_latex()`); this
//! module emits source text for the other targets directly from the EML AST, without lowering
//! or simplifying it first.
5
6use oxieml::{EmlNode, EmlTree};
7
8fn rust_expr(node: &EmlNode) -> String {
9    match node {
10        EmlNode::One => "1.0_f64".to_string(),
11        EmlNode::Const(c) => format!("{c}_f64"),
12        EmlNode::Var(i) => format!("x{i}"),
13        EmlNode::Eml { left, right } => format!(
14            "(({}).exp() - ({}).ln())",
15            rust_expr(left),
16            rust_expr(right)
17        ),
18    }
19}
20
21fn numpy_expr(node: &EmlNode) -> String {
22    match node {
23        EmlNode::One => "1.0".to_string(),
24        EmlNode::Const(c) => format!("{c}"),
25        EmlNode::Var(i) => format!("x{i}"),
26        EmlNode::Eml { left, right } => format!(
27            "(np.exp({}) - np.log({}))",
28            numpy_expr(left),
29            numpy_expr(right)
30        ),
31    }
32}
33
34fn sympy_expr(node: &EmlNode) -> String {
35    match node {
36        EmlNode::One => "1".to_string(),
37        EmlNode::Const(c) => format!("{c}"),
38        EmlNode::Var(i) => format!("x{i}"),
39        EmlNode::Eml { left, right } => {
40            format!("(exp({}) - log({}))", sympy_expr(left), sympy_expr(right))
41        }
42    }
43}
44
45/// Generate a standalone Rust function `f(x0, x1, ...) -> f64` for the tree.
46#[must_use]
47pub fn rust_code(tree: &EmlTree) -> String {
48    let n = tree.num_vars().max(1);
49    let args = (0..n)
50        .map(|i| format!("x{i}: f64"))
51        .collect::<Vec<_>>()
52        .join(", ");
53    format!("fn f({args}) -> f64 {{ {} }}", rust_expr(&tree.root))
54}
55
56/// Generate a NumPy-compatible Python lambda body for the tree.
57#[must_use]
58pub fn numpy_code(tree: &EmlTree) -> String {
59    let n = tree.num_vars().max(1);
60    let args = (0..n)
61        .map(|i| format!("x{i}"))
62        .collect::<Vec<_>>()
63        .join(", ");
64    format!("lambda {args}: {}", numpy_expr(&tree.root))
65}
66
67/// Generate a SymPy expression string for the tree.
68///
69/// The result uses SymPy's `exp` / `log` and bare symbols `x0, x1, …`; it parses under
70/// `from sympy import exp, log, symbols` (with the `xi` declared as symbols) and can then be
71/// `simplify`-ed Python-side. The variables are listed in the doc comment of the emitted code.
72#[must_use]
73pub fn sympy_code(tree: &EmlTree) -> String {
74    let n = tree.num_vars().max(1);
75    let syms = (0..n)
76        .map(|i| format!("x{i}"))
77        .collect::<Vec<_>>()
78        .join(" ");
79    format!("# symbols('{syms}')\n{}", sympy_expr(&tree.root))
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    // Every test below uses `exp(x0)` in canonical EML form. It is a single-variable tree
    // containing at least one `eml` node, so each emitter must produce a one-parameter header
    // and both its exp and log spellings.

    #[test]
    fn rust_codegen_for_exp() {
        let tree = oxieml::Canonical::exp(&EmlTree::var(0));
        let code = rust_code(&tree);
        assert!(
            code.starts_with("fn f(x0: f64) -> f64 { "),
            "unexpected signature: {code}"
        );
        assert!(code.ends_with(" }"), "unexpected body terminator: {code}");
        assert!(code.contains(".exp()"));
        assert!(code.contains(".ln()"));
    }

    #[test]
    fn numpy_codegen_for_exp() {
        let tree = oxieml::Canonical::exp(&EmlTree::var(0));
        let code = numpy_code(&tree);
        assert!(code.starts_with("lambda x0: "), "unexpected header: {code}");
        assert!(code.contains("np.exp("));
        assert!(code.contains("np.log("));
    }

    #[test]
    fn sympy_codegen_for_exp() {
        let tree = oxieml::Canonical::exp(&EmlTree::var(0));
        let code = sympy_code(&tree);
        assert!(
            code.starts_with("# symbols('x0')\n"),
            "unexpected header: {code}"
        );
        // Header comment plus exactly one expression line.
        assert_eq!(code.lines().count(), 2);
        assert!(code.contains("exp("));
        assert!(code.contains("log("));
        assert!(code.contains("x0"));
    }
}
109}