1use oxieml::{EmlNode, EmlTree};
7
8fn rust_expr(node: &EmlNode) -> String {
9 match node {
10 EmlNode::One => "1.0_f64".to_string(),
11 EmlNode::Const(c) => format!("{c}_f64"),
12 EmlNode::Var(i) => format!("x{i}"),
13 EmlNode::Eml { left, right } => format!(
14 "(({}).exp() - ({}).ln())",
15 rust_expr(left),
16 rust_expr(right)
17 ),
18 }
19}
20
21fn numpy_expr(node: &EmlNode) -> String {
22 match node {
23 EmlNode::One => "1.0".to_string(),
24 EmlNode::Const(c) => format!("{c}"),
25 EmlNode::Var(i) => format!("x{i}"),
26 EmlNode::Eml { left, right } => format!(
27 "(np.exp({}) - np.log({}))",
28 numpy_expr(left),
29 numpy_expr(right)
30 ),
31 }
32}
33
34fn sympy_expr(node: &EmlNode) -> String {
35 match node {
36 EmlNode::One => "1".to_string(),
37 EmlNode::Const(c) => format!("{c}"),
38 EmlNode::Var(i) => format!("x{i}"),
39 EmlNode::Eml { left, right } => {
40 format!("(exp({}) - log({}))", sympy_expr(left), sympy_expr(right))
41 }
42 }
43}
44
45#[must_use]
47pub fn rust_code(tree: &EmlTree) -> String {
48 let n = tree.num_vars().max(1);
49 let args = (0..n)
50 .map(|i| format!("x{i}: f64"))
51 .collect::<Vec<_>>()
52 .join(", ");
53 format!("fn f({args}) -> f64 {{ {} }}", rust_expr(&tree.root))
54}
55
56#[must_use]
58pub fn numpy_code(tree: &EmlTree) -> String {
59 let n = tree.num_vars().max(1);
60 let args = (0..n)
61 .map(|i| format!("x{i}"))
62 .collect::<Vec<_>>()
63 .join(", ");
64 format!("lambda {args}: {}", numpy_expr(&tree.root))
65}
66
67#[must_use]
73pub fn sympy_code(tree: &EmlTree) -> String {
74 let n = tree.num_vars().max(1);
75 let syms = (0..n)
76 .map(|i| format!("x{i}"))
77 .collect::<Vec<_>>()
78 .join(" ");
79 format!("# symbols('{syms}')\n{}", sympy_expr(&tree.root))
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    // Every backend is exercised on the canonical EML encoding of `exp(x0)`.

    #[test]
    fn rust_codegen_for_exp() {
        let source = rust_code(&oxieml::Canonical::exp(&EmlTree::var(0)));
        for needle in ["fn f(", ".exp()"] {
            assert!(source.contains(needle), "expected `{needle}` in {source:?}");
        }
    }

    #[test]
    fn numpy_codegen_for_exp() {
        let source = numpy_code(&oxieml::Canonical::exp(&EmlTree::var(0)));
        assert!(source.contains("np.exp"), "expected `np.exp` in {source:?}");
    }

    #[test]
    fn sympy_codegen_for_exp() {
        let source = sympy_code(&oxieml::Canonical::exp(&EmlTree::var(0)));
        for needle in ["exp(", "log(", "x0"] {
            assert!(source.contains(needle), "expected `{needle}` in {source:?}");
        }
    }
}