1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Expr, Ident, Token, parse_macro_input};
5
6#[cfg(feature = "optimization")]
8use egglog::EGraph;
9
10#[proc_macro]
37pub fn optimize_compile_time(input: TokenStream) -> TokenStream {
38 let input = parse_macro_input!(input as OptimizeInput);
39
40 let ast = match expr_to_ast(&input.expr) {
42 Ok(ast) => ast,
43 Err(e) => {
44 return syn::Error::new_spanned(
45 &input.expr,
46 format!("Failed to parse expression: {e}"),
47 )
48 .to_compile_error()
49 .into();
50 }
51 };
52
53 let optimized_ast = run_compile_time_optimization(&ast);
55
56 let generated_code = ast_to_rust_expr(&optimized_ast, &input.vars);
58
59 quote! {
61 {
62 #generated_code
63 }
64 }
65 .into()
66}
67
68struct OptimizeInput {
70 expr: Expr,
71 vars: Vec<Ident>,
72}
73
74impl syn::parse::Parse for OptimizeInput {
75 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
76 let expr = input.parse()?;
77 input.parse::<Token![,]>()?;
78
79 let content;
80 syn::bracketed!(content in input);
81
82 let vars = content
83 .parse_terminated(Ident::parse, Token![,])?
84 .into_iter()
85 .collect();
86
87 Ok(OptimizeInput { expr, vars })
88 }
89}
90
/// Symbolic expression tree produced from the macro input and simplified before code generation.
#[derive(Debug, Clone, PartialEq)]
enum CompileTimeAST {
    /// Index into the macro's identifier list (`var::<N>()`).
    Variable(usize),
    /// Numeric literal.
    Constant(f64),
    /// `lhs + rhs`.
    Add(Box<Self>, Box<Self>),
    /// `lhs * rhs`.
    Mul(Box<Self>, Box<Self>),
    /// `lhs - rhs`.
    Sub(Box<Self>, Box<Self>),
    /// Sine of the operand.
    Sin(Box<Self>),
    /// Cosine of the operand.
    Cos(Box<Self>),
    /// `e` raised to the operand.
    Exp(Box<Self>),
    /// Natural logarithm of the operand.
    Ln(Box<Self>),
    /// `base ^ exponent`; also used to encode division (`b^-1`) and square root (`a^0.5`).
    Pow(Box<Self>, Box<Self>),
}
105
106impl CompileTimeAST {
107 #[cfg(feature = "optimization")]
109 fn to_egglog(&self) -> String {
110 match self {
111 CompileTimeAST::Variable(id) => format!("(Var \"x{id}\")"),
112 CompileTimeAST::Constant(val) => {
113 if val.fract() == 0.0 {
114 format!("(Num {val:.1})")
115 } else {
116 format!("(Num {val})")
117 }
118 }
119 CompileTimeAST::Add(left, right) => {
120 format!("(Add {} {})", left.to_egglog(), right.to_egglog())
121 }
122 CompileTimeAST::Mul(left, right) => {
123 format!("(Mul {} {})", left.to_egglog(), right.to_egglog())
124 }
125 CompileTimeAST::Sub(left, right) => {
126 format!("(Add {} (Neg {}))", left.to_egglog(), right.to_egglog())
128 }
129 CompileTimeAST::Sin(inner) => {
130 format!("(Sin {})", inner.to_egglog())
131 }
132 CompileTimeAST::Cos(inner) => {
133 format!("(Cos {})", inner.to_egglog())
134 }
135 CompileTimeAST::Exp(inner) => {
136 format!("(Exp {})", inner.to_egglog())
137 }
138 CompileTimeAST::Ln(inner) => {
139 format!("(Ln {})", inner.to_egglog())
140 }
141 CompileTimeAST::Pow(base, exp) => {
142 format!("(Pow {} {})", base.to_egglog(), exp.to_egglog())
143 }
144 }
145 }
146}
147
148fn expr_to_ast(expr: &Expr) -> Result<CompileTimeAST, String> {
150 match expr {
151 Expr::MethodCall(method_call) => {
153 let receiver_ast = expr_to_ast(&method_call.receiver)?;
154
155 match method_call.method.to_string().as_str() {
156 "add" => {
157 if method_call.args.len() != 1 {
158 return Err("add() requires exactly one argument".to_string());
159 }
160 let arg_ast = expr_to_ast(&method_call.args[0])?;
161 Ok(CompileTimeAST::Add(
162 Box::new(receiver_ast),
163 Box::new(arg_ast),
164 ))
165 }
166 "sub" => {
167 if method_call.args.len() != 1 {
168 return Err("sub() requires exactly one argument".to_string());
169 }
170 let arg_ast = expr_to_ast(&method_call.args[0])?;
171 Ok(CompileTimeAST::Sub(
172 Box::new(receiver_ast),
173 Box::new(arg_ast),
174 ))
175 }
176 "mul" => {
177 if method_call.args.len() != 1 {
178 return Err("mul() requires exactly one argument".to_string());
179 }
180 let arg_ast = expr_to_ast(&method_call.args[0])?;
181 Ok(CompileTimeAST::Mul(
182 Box::new(receiver_ast),
183 Box::new(arg_ast),
184 ))
185 }
186 "div" => {
187 if method_call.args.len() != 1 {
188 return Err("div() requires exactly one argument".to_string());
189 }
190 let arg_ast = expr_to_ast(&method_call.args[0])?;
191 Ok(CompileTimeAST::Mul(
193 Box::new(receiver_ast),
194 Box::new(CompileTimeAST::Pow(
195 Box::new(arg_ast),
196 Box::new(CompileTimeAST::Constant(-1.0)),
197 )),
198 ))
199 }
200 "pow" => {
201 if method_call.args.len() != 1 {
202 return Err("pow() requires exactly one argument".to_string());
203 }
204 let arg_ast = expr_to_ast(&method_call.args[0])?;
205 Ok(CompileTimeAST::Pow(
206 Box::new(receiver_ast),
207 Box::new(arg_ast),
208 ))
209 }
210 "sin" => {
211 if !method_call.args.is_empty() {
212 return Err("sin() takes no arguments".to_string());
213 }
214 Ok(CompileTimeAST::Sin(Box::new(receiver_ast)))
215 }
216 "cos" => {
217 if !method_call.args.is_empty() {
218 return Err("cos() takes no arguments".to_string());
219 }
220 Ok(CompileTimeAST::Cos(Box::new(receiver_ast)))
221 }
222 "exp" => {
223 if !method_call.args.is_empty() {
224 return Err("exp() takes no arguments".to_string());
225 }
226 Ok(CompileTimeAST::Exp(Box::new(receiver_ast)))
227 }
228 "ln" => {
229 if !method_call.args.is_empty() {
230 return Err("ln() takes no arguments".to_string());
231 }
232 Ok(CompileTimeAST::Ln(Box::new(receiver_ast)))
233 }
234 "sqrt" => {
235 if !method_call.args.is_empty() {
236 return Err("sqrt() takes no arguments".to_string());
237 }
238 Ok(CompileTimeAST::Pow(
240 Box::new(receiver_ast),
241 Box::new(CompileTimeAST::Constant(0.5)),
242 ))
243 }
244 "neg" => {
245 if !method_call.args.is_empty() {
246 return Err("neg() takes no arguments".to_string());
247 }
248 Ok(CompileTimeAST::Mul(
250 Box::new(CompileTimeAST::Constant(-1.0)),
251 Box::new(receiver_ast),
252 ))
253 }
254 _ => Err(format!("Unknown method: {}", method_call.method)),
255 }
256 }
257
258 Expr::Call(call) => {
260 if let Expr::Path(path) = &*call.func {
261 if let Some(segment) = path.path.segments.last() {
262 match segment.ident.to_string().as_str() {
263 "var" => {
264 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
266 if let Some(syn::GenericArgument::Const(const_expr)) =
267 args.args.first()
268 {
269 if let Expr::Lit(syn::ExprLit {
270 lit: syn::Lit::Int(lit_int),
271 ..
272 }) = const_expr
273 {
274 let var_id: usize = lit_int
275 .base10_parse()
276 .map_err(|_| "Invalid variable ID".to_string())?;
277 return Ok(CompileTimeAST::Variable(var_id));
278 }
279 }
280 }
281 Err("Invalid var::<ID>() syntax".to_string())
282 }
283 "constant" => {
284 if call.args.len() != 1 {
285 return Err("constant() requires exactly one argument".to_string());
286 }
287
288 match &call.args[0] {
290 Expr::Lit(syn::ExprLit { lit: syn::Lit::Float(lit_float), .. }) => {
292 let value: f64 = lit_float.base10_parse()
293 .map_err(|_| "Invalid float literal".to_string())?;
294 Ok(CompileTimeAST::Constant(value))
295 }
296 Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) => {
298 let value: f64 = lit_int.base10_parse::<i64>()
299 .map_err(|_| "Invalid int literal".to_string())? as f64;
300 Ok(CompileTimeAST::Constant(value))
301 }
302 Expr::Unary(unary) => {
304 match unary.op {
305 syn::UnOp::Deref(_) => {
306 Err("constant() with variable dereference not supported in compile-time optimization".to_string())
310 }
311 syn::UnOp::Neg(_) => {
312 match &*unary.expr {
314 Expr::Lit(syn::ExprLit { lit: syn::Lit::Float(lit_float), .. }) => {
315 let value: f64 = lit_float.base10_parse()
316 .map_err(|_| "Invalid float literal".to_string())?;
317 Ok(CompileTimeAST::Constant(-value))
318 }
319 Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) => {
320 let value: f64 = lit_int.base10_parse::<i64>()
321 .map_err(|_| "Invalid int literal".to_string())? as f64;
322 Ok(CompileTimeAST::Constant(-value))
323 }
324 _ => Err("constant() with complex negative expression not supported".to_string())
325 }
326 }
327 _ => Err("constant() with unsupported unary operator".to_string())
328 }
329 }
330 _ => Err("constant() argument must be a numeric literal (variables not supported in compile-time optimization)".to_string())
332 }
333 }
334 _ => Err(format!("Unknown function: {}", segment.ident)),
335 }
336 } else {
337 Err("Invalid function call".to_string())
338 }
339 } else {
340 Err("Complex function calls not supported".to_string())
341 }
342 }
343
344 _ => Err("Unsupported expression type".to_string()),
345 }
346}
347
/// Optimizes `ast` with egglog.
///
/// The term is loaded into an e-graph seeded by `create_egglog_with_math_rules`, the rules run
/// for 3 iterations, and a term is extracted for the original expression and converted back.
/// This is best-effort: if any step fails (the rules fail to load, egglog rejects a command,
/// or the extracted text cannot be parsed), a clone of the input is returned unchanged.
#[cfg(feature = "optimization")]
fn run_compile_time_optimization(ast: &CompileTimeAST) -> CompileTimeAST {
    const EXPR_ID: &str = "expr_0";

    let optimize = || -> Option<CompileTimeAST> {
        let mut egraph = create_egglog_with_math_rules().ok()?;
        let term = ast.to_egglog();

        egraph
            .parse_and_run_program(None, &format!("(let {EXPR_ID} {term})"))
            .ok()?;
        egraph.parse_and_run_program(None, "(run 3)").ok()?;

        let extracted = egraph
            .parse_and_run_program(None, &format!("(extract {EXPR_ID})"))
            .ok()?;
        parse_egglog_result(&extracted.join("\n")).ok()
    };

    optimize().unwrap_or_else(|| ast.clone())
}
383
/// Creates an egglog e-graph with the `Math` datatype and the simplification rule set loaded.
///
/// Every constructor a rule mentions must be declared in the `Math` datatype. egglog reports
/// an error for an undeclared one, which makes this function return `Err`, and
/// `run_compile_time_optimization` then silently skips optimization entirely. That is why there
/// is no `x^0.5 -> sqrt(x)` rule: `Math` has no `Sqrt` constructor, and `parse_sexpr` /
/// `CompileTimeAST` could not represent an extracted `Sqrt` term anyway.
///
/// Note: `(Mul a (Num 0.0)) -> (Num 0.0)` is an algebraic identity that does not hold for
/// NaN or infinite `a` under IEEE arithmetic.
///
/// # Errors
/// Returns a message wrapping egglog's error if the program fails to parse or type-check.
#[cfg(feature = "optimization")]
fn create_egglog_with_math_rules() -> Result<EGraph, String> {
    let mut egraph = EGraph::default();

    let program = r"
; Mathematical expression datatype
(datatype Math
  (Num f64)
  (Var String)
  (Add Math Math)
  (Mul Math Math)
  (Neg Math)
  (Pow Math Math)
  (Ln Math)
  (Exp Math)
  (Sin Math)
  (Cos Math))

; SAFE SIMPLIFICATION RULES (no expansion)
; Identity rules
(rewrite (Add a (Num 0.0)) a)
(rewrite (Add (Num 0.0) a) a)
(rewrite (Mul a (Num 1.0)) a)
(rewrite (Mul (Num 1.0) a) a)
(rewrite (Mul a (Num 0.0)) (Num 0.0))
(rewrite (Mul (Num 0.0) a) (Num 0.0))
(rewrite (Pow a (Num 0.0)) (Num 1.0))
(rewrite (Pow a (Num 1.0)) a)

; SAFE transcendental identities (only simplifying)
(rewrite (Ln (Exp x)) x)
; Deliberately omitted (expanding): (rewrite (Exp (Add a b)) (Mul (Exp a) (Exp b)))

; SAFE specific patterns (no general commutativity/associativity)
(rewrite (Ln (Mul (Exp a) (Exp b))) (Add a b))

; Double negation
(rewrite (Neg (Neg x)) x)

; Power simplifications
(rewrite (Pow (Exp x) y) (Exp (Mul x y)))
";

    egraph
        .parse_and_run_program(None, program)
        .map_err(|e| format!("Failed to initialize egglog: {e}"))?;

    Ok(egraph)
}
436
/// Converts the text printed by an egglog `(extract ...)` command back into a [`CompileTimeAST`].
///
/// Leading and trailing whitespace is ignored.
///
/// # Errors
/// Propagates any error from `parse_sexpr`.
#[cfg(feature = "optimization")]
fn parse_egglog_result(output: &str) -> Result<CompileTimeAST, String> {
    parse_sexpr(output.trim())
}
443
/// Recursively parses one parenthesised egglog `Math` term into a [`CompileTimeAST`].
///
/// `(Var "xN")` becomes `Variable(N)`, and `(Neg x)` (which has no AST node) is rebuilt as
/// `Mul(Constant(-1.0), x)`. Children are parsed left to right.
///
/// # Errors
/// Returns a message for text that is not wrapped in parentheses, an empty term, a wrong
/// number of arguments, an unparsable number or variable name, or an unknown constructor.
#[cfg(feature = "optimization")]
fn parse_sexpr(s: &str) -> Result<CompileTimeAST, String> {
    let s = s.trim();

    let body = match s.strip_prefix('(').and_then(|rest| rest.strip_suffix(')')) {
        Some(body) => body,
        None => return Err("Invalid s-expression format".to_string()),
    };

    let tokens = tokenize_sexpr(body)?;
    let (head, args) = match tokens.split_first() {
        Some((head, args)) => (head.as_str(), args),
        None => return Err("Empty s-expression".to_string()),
    };

    match (head, args) {
        ("Num", [literal]) => literal
            .parse::<f64>()
            .map(CompileTimeAST::Constant)
            .map_err(|_| "Invalid number format".to_string()),
        ("Num", _) => Err("Num requires exactly one argument".to_string()),

        ("Var", [name]) => {
            let index_text = name
                .trim_matches('"')
                .strip_prefix('x')
                .ok_or_else(|| "Invalid variable name format".to_string())?;
            index_text
                .parse::<usize>()
                .map(CompileTimeAST::Variable)
                .map_err(|_| "Invalid variable index".to_string())
        }
        ("Var", _) => Err("Var requires exactly one argument".to_string()),

        ("Add", [lhs, rhs]) => Ok(CompileTimeAST::Add(
            Box::new(parse_sexpr(lhs)?),
            Box::new(parse_sexpr(rhs)?),
        )),
        ("Add", _) => Err("Add requires exactly two arguments".to_string()),

        ("Mul", [lhs, rhs]) => Ok(CompileTimeAST::Mul(
            Box::new(parse_sexpr(lhs)?),
            Box::new(parse_sexpr(rhs)?),
        )),
        ("Mul", _) => Err("Mul requires exactly two arguments".to_string()),

        // Negation has no AST node of its own: -x is represented as -1 * x.
        ("Neg", [operand]) => Ok(CompileTimeAST::Mul(
            Box::new(CompileTimeAST::Constant(-1.0)),
            Box::new(parse_sexpr(operand)?),
        )),
        ("Neg", _) => Err("Neg requires exactly one argument".to_string()),

        ("Pow", [base, exponent]) => Ok(CompileTimeAST::Pow(
            Box::new(parse_sexpr(base)?),
            Box::new(parse_sexpr(exponent)?),
        )),
        ("Pow", _) => Err("Pow requires exactly two arguments".to_string()),

        ("Ln", [operand]) => Ok(CompileTimeAST::Ln(Box::new(parse_sexpr(operand)?))),
        ("Ln", _) => Err("Ln requires exactly one argument".to_string()),

        ("Exp", [operand]) => Ok(CompileTimeAST::Exp(Box::new(parse_sexpr(operand)?))),
        ("Exp", _) => Err("Exp requires exactly one argument".to_string()),

        ("Sin", [operand]) => Ok(CompileTimeAST::Sin(Box::new(parse_sexpr(operand)?))),
        ("Sin", _) => Err("Sin requires exactly one argument".to_string()),

        ("Cos", [operand]) => Ok(CompileTimeAST::Cos(Box::new(parse_sexpr(operand)?))),
        ("Cos", _) => Err("Cos requires exactly one argument".to_string()),

        (other, _) => Err(format!("Unknown function: {other}")),
    }
}
550
551#[cfg(feature = "optimization")]
/// Splits the *inside* of an s-expression into its top-level tokens.
///
/// Atoms are separated by whitespace. A parenthesised sub-expression, including its
/// parentheses, is kept as a single token. A double-quoted string, including its quotes, is
/// kept as part of the current token. Parentheses and whitespace inside quotes are ordinary
/// characters.
///
/// Example: `Add (Num 1.0) (Var "x0")` -> `["Add", "(Num 1.0)", "(Var \"x0\")"]`.
///
/// # Errors
/// Returns an error for an unmatched `)`, an unclosed `(`, or an unterminated string literal.
fn tokenize_sexpr(s: &str) -> Result<Vec<String>, String> {
    // Moves the accumulated text (trimmed, if non-empty) into `tokens` and resets `current`.
    fn flush(current: &mut String, tokens: &mut Vec<String>) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            tokens.push(trimmed.to_string());
        }
        current.clear();
    }

    let mut tokens = Vec::new();
    let mut current = String::new();
    let mut depth: usize = 0;
    let mut in_quotes = false;

    for ch in s.chars() {
        match ch {
            '"' => {
                in_quotes = !in_quotes;
                current.push(ch);
            }
            _ if in_quotes => current.push(ch),
            '(' => {
                // An atom directly followed by `(` ends at the parenthesis.
                if depth == 0 {
                    flush(&mut current, &mut tokens);
                }
                depth += 1;
                current.push(ch);
            }
            ')' => {
                depth = depth
                    .checked_sub(1)
                    .ok_or_else(|| "Unbalanced ')' in s-expression".to_string())?;
                current.push(ch);
                if depth == 0 {
                    flush(&mut current, &mut tokens);
                }
            }
            c if c.is_whitespace() && depth == 0 => flush(&mut current, &mut tokens),
            _ => current.push(ch),
        }
    }

    if in_quotes {
        return Err("Unterminated string literal in s-expression".to_string());
    }
    if depth != 0 {
        return Err("Unbalanced '(' in s-expression".to_string());
    }
    flush(&mut current, &mut tokens);

    Ok(tokens)
}
599
600fn ast_to_rust_expr(ast: &CompileTimeAST, vars: &[Ident]) -> TokenStream2 {
602 match ast {
603 CompileTimeAST::Constant(c) => {
604 quote! { #c }
605 }
606 CompileTimeAST::Variable(idx) => {
607 if *idx < vars.len() {
608 let var = &vars[*idx];
609 quote! { #var }
610 } else {
611 quote! { 0.0 }
612 }
613 }
614 CompileTimeAST::Add(left, right) => {
615 let left_expr = ast_to_rust_expr(left, vars);
616 let right_expr = ast_to_rust_expr(right, vars);
617 quote! { (#left_expr + #right_expr) }
618 }
619 CompileTimeAST::Sub(left, right) => {
620 let left_expr = ast_to_rust_expr(left, vars);
621 let right_expr = ast_to_rust_expr(right, vars);
622 quote! { (#left_expr - #right_expr) }
623 }
624 CompileTimeAST::Mul(left, right) => {
625 let left_expr = ast_to_rust_expr(left, vars);
626 let right_expr = ast_to_rust_expr(right, vars);
627 quote! { (#left_expr * #right_expr) }
628 }
629 CompileTimeAST::Pow(base, exp) => {
630 let base_expr = ast_to_rust_expr_with_parens(base, vars);
631 let exp_expr = ast_to_rust_expr(exp, vars);
632 quote! { #base_expr.powf(#exp_expr) }
633 }
634 CompileTimeAST::Sin(inner) => {
635 let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
636 quote! { #inner_expr.sin() }
637 }
638 CompileTimeAST::Cos(inner) => {
639 let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
640 quote! { #inner_expr.cos() }
641 }
642 CompileTimeAST::Exp(inner) => {
643 let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
644 quote! { #inner_expr.exp() }
645 }
646 CompileTimeAST::Ln(inner) => {
647 let inner_expr = ast_to_rust_expr_with_parens(inner, vars);
648 quote! { #inner_expr.ln() }
649 }
650 }
651}
652
653fn ast_to_rust_expr_with_parens(ast: &CompileTimeAST, vars: &[Ident]) -> TokenStream2 {
655 match ast {
656 CompileTimeAST::Constant(_) | CompileTimeAST::Variable(_) => ast_to_rust_expr(ast, vars),
658 _ => {
660 let expr = ast_to_rust_expr(ast, vars);
661 quote! { (#expr) }
662 }
663 }
664}
665
666#[cfg(not(feature = "optimization"))]
668fn run_compile_time_optimization(ast: &CompileTimeAST) -> CompileTimeAST {
669 apply_basic_optimizations(ast)
671}
672
673#[cfg(not(feature = "optimization"))]
675fn apply_basic_optimizations(ast: &CompileTimeAST) -> CompileTimeAST {
676 match ast {
677 CompileTimeAST::Add(left, right) => {
679 let left_opt = apply_basic_optimizations(left);
680 let right_opt = apply_basic_optimizations(right);
681
682 if let CompileTimeAST::Constant(0.0) = right_opt {
683 left_opt
684 } else if let CompileTimeAST::Constant(0.0) = left_opt {
685 right_opt
686 } else {
687 CompileTimeAST::Add(Box::new(left_opt), Box::new(right_opt))
688 }
689 }
690 CompileTimeAST::Mul(left, right) => {
692 let left_opt = apply_basic_optimizations(left);
693 let right_opt = apply_basic_optimizations(right);
694
695 if let CompileTimeAST::Constant(1.0) = right_opt {
696 left_opt
697 } else if let CompileTimeAST::Constant(1.0) = left_opt {
698 right_opt
699 } else if let CompileTimeAST::Constant(0.0) = right_opt {
700 CompileTimeAST::Constant(0.0)
701 } else if let CompileTimeAST::Constant(0.0) = left_opt {
702 CompileTimeAST::Constant(0.0)
703 } else {
704 CompileTimeAST::Mul(Box::new(left_opt), Box::new(right_opt))
705 }
706 }
707 CompileTimeAST::Ln(inner) => {
709 let inner_opt = apply_basic_optimizations(inner);
710 if let CompileTimeAST::Exp(exp_inner) = &inner_opt {
711 (**exp_inner).clone()
712 } else {
713 CompileTimeAST::Ln(Box::new(inner_opt))
714 }
715 }
716 CompileTimeAST::Sub(left, right) => CompileTimeAST::Sub(
718 Box::new(apply_basic_optimizations(left)),
719 Box::new(apply_basic_optimizations(right)),
720 ),
721 CompileTimeAST::Pow(base, exp) => CompileTimeAST::Pow(
722 Box::new(apply_basic_optimizations(base)),
723 Box::new(apply_basic_optimizations(exp)),
724 ),
725 CompileTimeAST::Sin(inner) => {
726 CompileTimeAST::Sin(Box::new(apply_basic_optimizations(inner)))
727 }
728 CompileTimeAST::Cos(inner) => {
729 CompileTimeAST::Cos(Box::new(apply_basic_optimizations(inner)))
730 }
731 CompileTimeAST::Exp(inner) => {
732 CompileTimeAST::Exp(Box::new(apply_basic_optimizations(inner)))
733 }
734 CompileTimeAST::Variable(_) | CompileTimeAST::Constant(_) => ast.clone(),
736 }
737}