use crate::lower::LoweredOp;
use crate::tree::EmlTree;
/// Compiles `tree` into the source text of a standalone Rust function named `fn_name`.
///
/// The tree is lowered, simplified, and then rendered by [`compile_lowered_to_rust`],
/// so the result has the shape `fn fn_name(vars: &[f64]) -> f64 { ... }`.
pub fn compile_to_rust(tree: &EmlTree, fn_name: &str) -> String {
    compile_lowered_to_rust(&tree.lower().simplify(), fn_name, tree.num_vars())
}
/// Renders an already-lowered expression as the source of a Rust function.
///
/// The generated function has the signature `fn fn_name(vars: &[f64]) -> f64`.
/// It binds `x0..x{num_vars-1}` from `vars` by indexing, so the generated code
/// panics when called with a slice shorter than `num_vars`. When at least one
/// variable is bound, a blank line separates the bindings from the returned
/// expression.
pub fn compile_lowered_to_rust(op: &LoweredOp, fn_name: &str, num_vars: usize) -> String {
    let bindings: String = (0..num_vars)
        .map(|idx| format!(" let x{idx} = vars[{idx}];\n"))
        .collect();
    let separator = if num_vars == 0 { "" } else { "\n" };
    let body = emit_rust_expr(op);
    format!("fn {fn_name}(vars: &[f64]) -> f64 {{\n{bindings}{separator} {body}\n}}\n")
}
/// Compiles `tree` into the source text of a Rust closure expression.
///
/// The result has the shape `|vars: &[f64]| -> f64 { ... }` and carries no
/// trailing newline, so it can be spliced directly into other generated code.
/// Variables are bound exactly as in [`compile_lowered_to_rust`].
pub fn compile_to_closure(tree: &EmlTree) -> String {
    let lowered = tree.lower();
    let simplified = lowered.simplify();
    let var_count = tree.num_vars();

    let mut src = String::from("|vars: &[f64]| -> f64 {\n");
    src.extend((0..var_count).map(|idx| format!(" let x{idx} = vars[{idx}];\n")));
    if var_count != 0 {
        src.push('\n');
    }
    src.push(' ');
    src.push_str(&emit_rust_expr(&simplified));
    src.push_str("\n}");
    src
}
/// Compiles `tree` into a single-point function plus a `{fn_name}_batch` wrapper.
///
/// The wrapper has the signature `fn {fn_name}_batch(data: &[Vec<f64>]) -> Vec<f64>`
/// and maps the single-point function over every row of `data`.
///
/// Whether the wrapper uses rayon's `par_iter` or a plain `iter` is decided by
/// `cfg!(feature = "parallel")`. That means it follows the features this crate
/// was built with, not the features of the crate that will compile the output.
/// The parallel variant emits `use rayon::prelude::*;`, so consumers of that
/// output need rayon available.
pub fn compile_to_rust_batch(tree: &EmlTree, fn_name: &str) -> String {
    let scalar_fn = compile_to_rust(tree, fn_name);
    // Only the prelude line and the iterator method differ between the two variants.
    let (prelude, iterate) = if cfg!(feature = "parallel") {
        ("use rayon::prelude::*;\n", "par_iter")
    } else {
        ("", "iter")
    };
    let batch_fn = format!(
        "fn {fn_name}_batch(data: &[Vec<f64>]) -> Vec<f64> {{\n\
         {prelude}data.{iterate}().map(|pt| {fn_name}(pt)).collect()\n\
         }}\n"
    );
    format!("{scalar_fn}\n{batch_fn}")
}
/// Renders a lowered expression as a single Rust expression string.
///
/// Binary operators are always wrapped in parentheses, and unary functions are
/// emitted as method calls on a parenthesised receiver (`(a).exp()`). The output
/// therefore never relies on Rust operator precedence; in particular, a negative
/// literal such as `-2_f64` cannot bind to a trailing method call.
///
/// Variables are referenced as `x{i}`, matching the `let x{i} = vars[{i}];`
/// bindings emitted by [`compile_lowered_to_rust`] and [`compile_to_closure`].
fn emit_rust_expr(op: &LoweredOp) -> String {
    match op {
        LoweredOp::NamedConst(nc) => emit_f64_literal(nc.value()),
        LoweredOp::Const(c) => emit_f64_literal(*c),
        LoweredOp::Var(i) => format!("x{i}"),
        LoweredOp::Add(a, b) => emit_binary(a, "+", b),
        LoweredOp::Sub(a, b) => emit_binary(a, "-", b),
        LoweredOp::Mul(a, b) => emit_binary(a, "*", b),
        LoweredOp::Div(a, b) => emit_binary(a, "/", b),
        LoweredOp::Pow(a, b) => {
            format!("({}).powf({})", emit_rust_expr(a), emit_rust_expr(b))
        }
        LoweredOp::Neg(a) => format!("(-({}))", emit_rust_expr(a)),
        LoweredOp::Exp(a) => emit_method_call(a, "exp"),
        LoweredOp::Ln(a) => emit_method_call(a, "ln"),
        LoweredOp::Sin(a) => emit_method_call(a, "sin"),
        LoweredOp::Cos(a) => emit_method_call(a, "cos"),
        LoweredOp::Tan(a) => emit_method_call(a, "tan"),
        LoweredOp::Sinh(a) => emit_method_call(a, "sinh"),
        LoweredOp::Cosh(a) => emit_method_call(a, "cosh"),
        LoweredOp::Tanh(a) => emit_method_call(a, "tanh"),
        LoweredOp::Arcsin(a) => emit_method_call(a, "asin"),
        LoweredOp::Arccos(a) => emit_method_call(a, "acos"),
        LoweredOp::Arctan(a) => emit_method_call(a, "atan"),
        LoweredOp::Arcsinh(a) => emit_method_call(a, "asinh"),
        LoweredOp::Arccosh(a) => emit_method_call(a, "acosh"),
        LoweredOp::Arctanh(a) => emit_method_call(a, "atanh"),
    }
}

/// Emits `(<lhs> <op> <rhs>)`, fully parenthesised so precedence never matters.
fn emit_binary(lhs: &LoweredOp, op: &str, rhs: &LoweredOp) -> String {
    format!("({} {op} {})", emit_rust_expr(lhs), emit_rust_expr(rhs))
}

/// Emits `(<arg>).<method>()` for a unary `f64` method such as `exp` or `atanh`.
fn emit_method_call(arg: &LoweredOp, method: &str) -> String {
    format!("({}).{method}()", emit_rust_expr(arg))
}

/// Renders an `f64` constant as Rust source that compiles and evaluates to it.
///
/// * NaN and ±infinity have no float literal form. They are emitted as
///   `f64::NAN`, `f64::INFINITY` and `f64::NEG_INFINITY`; the `{:e}`
///   formatter would print `NaN` / `inf`, which do not compile.
/// * Values within `1e-15` of `E`, `PI` or `SQRT_2` are deliberately snapped to
///   the named `std::f64::consts` constant. The same table is used for both
///   `Const` and `NamedConst`.
/// * Values within `1e-10` of an integer, with magnitude below `1e15`, are
///   deliberately snapped to that integer (e.g. `3_f64`).
/// * Everything else uses the shortest exponent form that round-trips exactly,
///   e.g. `3.0000000000000004e-1_f64`.
fn emit_f64_literal(c: f64) -> String {
    const NAMED: [(f64, &str); 3] = [
        (std::f64::consts::E, "std::f64::consts::E"),
        (std::f64::consts::PI, "std::f64::consts::PI"),
        (std::f64::consts::SQRT_2, "std::f64::consts::SQRT_2"),
    ];

    if c.is_nan() {
        return "f64::NAN".to_string();
    }
    if c.is_infinite() {
        let name = if c > 0.0 { "f64::INFINITY" } else { "f64::NEG_INFINITY" };
        return name.to_string();
    }
    if let Some(&(_, name)) = NAMED.iter().find(|&&(value, _)| (c - value).abs() < 1e-15) {
        return name.to_string();
    }

    let nearest = c.round();
    if (c - nearest).abs() < 1e-10 && c.abs() < 1e15 {
        // Cast the *rounded* value. `c as i64` truncates toward zero, so
        // 2.9999999999999996 would wrongly become `2_f64`.
        format!("{}_f64", nearest as i64)
    } else {
        // `{:e}` with no precision prints the shortest digits that parse back to
        // exactly `c`. A fixed `{:.15e}` keeps only 16 significant digits, which
        // is not enough to round-trip every f64.
        format!("{c:e}_f64")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every needle occurs in the generated source, and shows the
    /// source on failure.
    fn assert_contains_all(source: &str, needles: &[&str]) {
        for needle in needles {
            assert!(source.contains(needle), "missing `{needle}` in:\n{source}");
        }
    }

    // These tests rely on `eml(x0, 1)` compiling to code that calls `.exp()`.
    #[test]
    fn test_compile_exp() {
        let input = EmlTree::var(0);
        let unit = EmlTree::one();
        let tree = EmlTree::eml(&input, &unit);
        let source = compile_to_rust(&tree, "exp_fn");
        assert_contains_all(&source, &["fn exp_fn", "x0", ".exp()"]);
    }

    #[test]
    fn test_compile_euler() {
        let unit = EmlTree::one();
        let tree = EmlTree::eml(&unit, &unit);
        let source = compile_to_rust(&tree, "euler_fn");
        assert_contains_all(&source, &["fn euler_fn"]);
        // Either the constant was recognised as `E`, or it is computed via `exp`.
        assert!(["E", "exp"].iter().any(|needle| source.contains(needle)));
    }

    #[test]
    fn test_compile_closure() {
        let input = EmlTree::var(0);
        let unit = EmlTree::one();
        let tree = EmlTree::eml(&input, &unit);
        let source = compile_to_closure(&tree);
        assert_contains_all(&source, &["|vars: &[f64]| -> f64"]);
    }

    #[test]
    fn test_compile_no_vars() {
        let unit = EmlTree::one();
        let source = compile_to_rust(&unit, "const_fn");
        assert_contains_all(&source, &["fn const_fn"]);
        assert!(!source.contains("let x"), "constant function should bind no variables");
    }

    #[test]
    fn test_compile_to_rust_batch() {
        let input = EmlTree::var(0);
        let unit = EmlTree::one();
        let tree = EmlTree::eml(&input, &unit);
        let source = compile_to_rust_batch(&tree, "exp_fn");
        assert_contains_all(
            &source,
            &[
                "fn exp_fn(vars: &[f64]) -> f64",
                "fn exp_fn_batch(data: &[Vec<f64>]) -> Vec<f64>",
                ".collect()",
            ],
        );
        // The emitted iterator depends on the `parallel` feature of this crate.
        let expected_iter = if cfg!(feature = "parallel") {
            "par_iter"
        } else {
            "data.iter()"
        };
        assert_contains_all(&source, &[expected_iter]);
    }
}