mathhook_macros/
lib.rs

1//! Code generation macros for MathHook language bindings and expression construction
2//!
3//! This crate provides procedural macros to generate Python (PyO3) and Node.js (NAPI)
4//! bindings from a single function definition, eliminating code duplication while
5//! maintaining zero-cost abstractions.
6//!
7//! Also provides procedural macros for mathematical expression and symbol construction.
8//!
9//! # Supported Function Arities
10//!
11//! - **Unary**: sin(x), cos(x), exp(x)
12//! - **Binary**: pow(x, y), atan2(y, x)
13//! - **Variadic**: add(args...), mul(args...)
14//! - **Constants**: pi(), e()
15
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::{parse_macro_input, Ident};
19
20mod expr;
21mod function;
22mod symbol;
23
24/// Procedural macro for creating mathematical expressions with full syntax support
25///
26/// # Supported Operators
27///
28/// - **Basic arithmetic**: `+`, `-`, `*`, `/`
29/// - **Unary negation**: `-x`
30/// - **Power operations**: Both `**` and `.pow()` syntax
31/// - **Comparison operators**: `==`, `<`, `>`, `<=`, `>=`
32///
33/// # Supported Literals
34///
35/// - **Integers**: `42`, `-5`, `0`
36/// - **Floats**: `3.14`, `-2.5`, `0.0`
37/// - **Identifiers/symbols**: `x`, `theta`, `alpha_1`
38/// - **Mathematical constants**: `pi`, `e`, `i`
39///
40/// # Supported Constructs
41///
42/// - **Function calls**: `sin(x)`, `log(x, y)`, `f(a, b, c)`
43/// - **Parenthesized expressions**: `(2*x + 3)`, `((x+y))`
44/// - **Method calls**: `x.pow(2)`, `x.abs()`, `x.sqrt()`, `x.simplify()`
45///
46/// # Power Operator Syntax
47///
48/// The macro supports **two syntaxes** for exponentiation:
49///
50/// ## Method Syntax: `.pow()`
51///
52/// The `.pow()` method works everywhere without restrictions:
53///
54/// ```rust,ignore
55/// expr!(x.pow(2))              // x squared
56/// expr!(x.pow(y))              // x to the y
57/// expr!((x + 1).pow(3))        // (x+1) cubed
58/// expr!(2 * x.pow(3) + 5)      // Complex expression
59/// ```
60///
61/// **Advantages:**
62/// - Mirrors Rust's standard library API
63/// - Works seamlessly in any context
64/// - Clear, unambiguous precedence
65/// - No parentheses required
66///
67/// ## Infix Syntax: `**`
68///
69/// The `**` operator provides mathematical notation:
70///
71/// ```rust,ignore
72/// expr!(x ** 2)                // x squared
73/// expr!(x ** y)                // x to the y
74/// expr!((x + 1) ** 3)          // (x+1) cubed
75/// ```
76///
77/// **IMPORTANT - When to use parentheses with `**`:**
78///
79/// Due to token-level preprocessing, complex expressions with `**` should use
80/// parentheses for clarity when mixing with other operators:
81///
82/// ```rust,ignore
83/// // ✅ RECOMMENDED (clear precedence):
84/// expr!(2 * (x ** 3) + 5)      // Parentheses make precedence explicit
85/// expr!((x ** 2) + (y ** 2))   // Clear grouping
86///
87/// // ⚠️  WORKS BUT LESS CLEAR:
88/// expr!(x ** 2 + y ** 2)       // Relies on implicit precedence
89/// ```
90///
91/// **Right-associativity:**
92///
93/// Power operations are right-associative (like standard mathematical notation):
94///
95/// ```rust,ignore
96/// expr!(x ** 2 ** 3)           // Parsed as: x ** (2 ** 3) = x ** 8
97/// expr!(2 ** 3 ** 2)           // Parsed as: 2 ** (3 ** 2) = 2 ** 9 = 512
98/// ```
99///
100/// **Which syntax to use?**
101///
102/// - Use `.pow()` for guaranteed clarity in complex expressions
103/// - Use `**` when writing simple mathematical formulas that match notation
104/// - When in doubt, add parentheses: `expr!(2 * (x ** 3) + 5)`
105///
106/// # Examples
107///
108/// ```rust,ignore
109/// use mathhook_macros::expr;
110/// use mathhook_core::{Expression, symbol};
111///
112/// let x = symbol!(x);
113/// let y = symbol!(y);
114///
115/// // Basic operations
116/// let sum = expr!(x + 2);
117/// let product = expr!(2 * x);
118///
119/// // Power operations (both syntaxes work)
120/// let power1 = expr!(x ** 2);
121/// let power2 = expr!(x.pow(2));
122///
123/// // Complex expressions (use parentheses with ** for clarity)
124/// let quadratic = expr!((x ** 2) + 2*x + 1);
125/// let nested = expr!((x + 1) * (x - 1));
126/// let functions = expr!(sin(x ** 2));
127///
128/// // Comparison operators
129/// let equation = expr!(x ** 2 == 4);
130/// let inequality = expr!(x + 1 > y);
131///
132/// // Method calls
133/// let abs_val = expr!(x.abs());
134/// let sqrt_expr = expr!(x.sqrt());
135/// let simplified = expr!((x + x).simplify());
136/// ```
137///
138/// # Precedence Rules
139///
140/// Operator precedence from highest to lowest:
141///
142/// 1. **Method calls** (highest): `.pow()`, `.abs()`, `.sqrt()`, `.simplify()`
143/// 2. **Power** (right-associative): `**`
144/// 3. **Unary negation**: `-x`
145/// 4. **Multiplication/division**: `*`, `/`
146/// 5. **Addition/subtraction**: `+`, `-`
147/// 6. **Comparison operators** (lowest): `==`, `<`, `>`, `<=`, `>=`
148///
149/// Use parentheses to override precedence:
150///
151/// ```rust,ignore
152/// expr!(2 * x + 3)             // Parsed as: (2*x) + 3
153/// expr!((2 + 3) * x)           // Parentheses override: 5*x
154/// expr!(2 * (x ** 3))          // Recommended for clarity with **
155/// expr!((x + 1) ** 2)          // Power of sum
156/// ```
157#[proc_macro]
158pub fn expr(input: TokenStream) -> TokenStream {
159    expr::expr_impl(input)
160}
161
162/// Procedural macro for creating symbols with optional type specification
163///
164/// # Syntax
165///
166/// ```rust,ignore
167/// symbol!(x)                  // Scalar (default, commutative)
168/// symbol!(A; matrix)          // Matrix (noncommutative)
169/// symbol!(p; operator)        // Operator (noncommutative)
170/// symbol!(i; quaternion)      // Quaternion (noncommutative)
171/// symbol!("name")             // String literal for symbol name
172/// ```
173///
174/// # Symbol Types
175///
176/// - **scalar**: Commutative symbols (default) - variables like x, y, z
177/// - **matrix**: Noncommutative matrix symbols - A*B ≠ B*A
178/// - **operator**: Noncommutative operator symbols - for quantum mechanics [x,p] ≠ 0
179/// - **quaternion**: Noncommutative quaternion symbols - i*j = k, j*i = -k
180///
181/// # Examples
182///
183/// ```rust,ignore
184/// use mathhook_macros::symbol;
185///
186/// // Scalar symbols (commutative)
187/// let x = symbol!(x);
188/// let theta = symbol!(theta);
189///
190/// // Matrix symbols (noncommutative)
191/// let A = symbol!(A; matrix);
192/// let B = symbol!(B; matrix);
193///
194/// // Operator symbols (noncommutative)
195/// let p = symbol!(p; operator);
196/// let x_op = symbol!(x; operator);
197///
198/// // Quaternion symbols (noncommutative)
199/// let i = symbol!(i; quaternion);
200/// let j = symbol!(j; quaternion);
201/// ```
202#[proc_macro]
203pub fn symbol(input: TokenStream) -> TokenStream {
204    symbol::symbol_impl(input)
205}
206
207/// Procedural macro for creating multiple symbols at once
208///
209/// # Syntax
210///
211/// ```rust,ignore
212/// symbols![x, y, z]              // All scalars (default)
213/// symbols![A, B, C => matrix]    // All matrices
214/// symbols![p, x, H => operator]  // All operators
215/// symbols![i, j, k => quaternion] // All quaternions
216/// ```
217///
218/// # Returns
219///
220/// Returns `Vec<Symbol>` containing all created symbols.
221///
222/// # Examples
223///
224/// ```rust,ignore
225/// use mathhook_macros::symbols;
226///
227/// // Scalar symbols (default, commutative)
228/// let syms = symbols![x, y, z];
229/// assert_eq!(syms.len(), 3);
230///
231/// // Matrix symbols (noncommutative)
232/// let mats = symbols![A, B, C => matrix];
233/// assert_eq!(mats.len(), 3);
234///
235/// // Operator symbols (noncommutative)
236/// let ops = symbols![p, x, H => operator];
237/// assert_eq!(ops.len(), 3);
238///
239/// // Quaternion symbols (noncommutative)
240/// let quats = symbols![i, j, k => quaternion];
241/// assert_eq!(quats.len(), 3);
242/// ```
243#[proc_macro]
244pub fn symbols(input: TokenStream) -> TokenStream {
245    symbol::symbols_impl(input)
246}
247
248/// Procedural macro for creating function expressions
249///
250/// # Syntax
251///
252/// ```rust,ignore
253/// function!(sin)              // Zero args
254/// function!(sin, x)           // One arg (x is an Expression)
255/// function!(log, x, y)        // Two args
256/// function!(f, a, b, c)       // N args
257/// ```
258///
259/// # Examples
260///
261/// ```rust,ignore
262/// use mathhook_macros::function;
263/// use mathhook_core::expr;
264///
265/// // Zero-argument function
266/// let gamma_call = function!(gamma);
267///
268/// // Single argument
269/// let x = expr!(x);
270/// let sin_x = function!(sin, x);
271///
272/// // Multiple arguments
273/// let log_xy = function!(log, x, expr!(2));
274/// ```
275#[proc_macro]
276pub fn function(input: TokenStream) -> TokenStream {
277    function::function_impl(input)
278}
279
280/// Generate Python (PyO3) binding for a unary mathematical function
281///
282/// # Arguments
283///
284/// * `name` - The function name (e.g., `sin`)
285///
286/// # Examples
287///
288/// ```rust,ignore
289/// generate_python_binding!(sin);
290/// ```
291///
292/// Generates:
293///
294/// ```rust,ignore
295/// #[pyfunction]
296/// pub fn sin(x: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
297///     let expr = sympify_python(x)?;
298///     Ok(PyExpression {
299///         inner: Expression::function("sin", vec![expr]),
300///     })
301/// }
302/// ```
303#[proc_macro]
304pub fn generate_python_binding(input: TokenStream) -> TokenStream {
305    let name = parse_macro_input!(input as Ident);
306    let name_str = name.to_string();
307
308    let expanded = quote! {
309        #[pyo3::pyfunction]
310        pub fn #name(x: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<crate::PyExpression> {
311            use mathhook_core::Expression;
312            let expr = crate::helpers::sympify_python(x)?;
313            Ok(crate::PyExpression {
314                inner: Expression::function(#name_str, vec![expr]),
315            })
316        }
317    };
318
319    TokenStream::from(expanded)
320}
321
322/// Generate Python (PyO3) binding for a binary mathematical function
323///
324/// # Arguments
325///
326/// * `name` - The function name (e.g., `pow`)
327///
328/// # Examples
329///
330/// ```rust,ignore
331/// generate_python_binary_binding!(pow);
332/// ```
333///
334/// Generates:
335///
336/// ```rust,ignore
337/// #[pyfunction]
338/// pub fn pow(x: &Bound<'_, PyAny>, y: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
339///     let expr1 = sympify_python(x)?;
340///     let expr2 = sympify_python(y)?;
341///     Ok(PyExpression {
342///         inner: Expression::function("pow", vec![expr1, expr2]),
343///     })
344/// }
345/// ```
346#[proc_macro]
347pub fn generate_python_binary_binding(input: TokenStream) -> TokenStream {
348    let name = parse_macro_input!(input as Ident);
349    let name_str = name.to_string();
350
351    let expanded = quote! {
352        #[pyo3::pyfunction]
353        pub fn #name(
354            x: &pyo3::Bound<'_, pyo3::PyAny>,
355            y: &pyo3::Bound<'_, pyo3::PyAny>,
356        ) -> pyo3::PyResult<crate::PyExpression> {
357            use mathhook_core::Expression;
358            let expr1 = crate::helpers::sympify_python(x)?;
359            let expr2 = crate::helpers::sympify_python(y)?;
360            Ok(crate::PyExpression {
361                inner: Expression::function(#name_str, vec![expr1, expr2]),
362            })
363        }
364    };
365
366    TokenStream::from(expanded)
367}
368
369/// Generate Python (PyO3) binding for a variadic mathematical function
370///
371/// # Arguments
372///
373/// * `name` - The function name (e.g., `add`)
374///
375/// # Examples
376///
377/// ```rust,ignore
378/// generate_python_variadic_binding!(add);
379/// ```
380///
381/// Generates:
382///
383/// ```rust,ignore
384/// #[pyfunction]
385/// #[pyo3(signature = (*args))]
386/// pub fn add(args: &Bound<'_, PyTuple>) -> PyResult<PyExpression> {
387///     let mut exprs = Vec::new();
388///     for arg in args {
389///         exprs.push(sympify_python(&arg)?);
390///     }
391///     Ok(PyExpression {
392///         inner: Expression::function("add", exprs),
393///     })
394/// }
395/// ```
396#[proc_macro]
397pub fn generate_python_variadic_binding(input: TokenStream) -> TokenStream {
398    let name = parse_macro_input!(input as Ident);
399    let name_str = name.to_string();
400
401    let expanded = quote! {
402        #[pyo3::pyfunction]
403        #[pyo3(signature = (*args))]
404        pub fn #name(args: &pyo3::Bound<'_, pyo3::types::PyTuple>) -> pyo3::PyResult<crate::PyExpression> {
405            use mathhook_core::Expression;
406            let mut exprs = Vec::new();
407            for arg in args {
408                exprs.push(crate::helpers::sympify_python(&arg)?);
409            }
410            Ok(crate::PyExpression {
411                inner: Expression::function(#name_str, exprs),
412            })
413        }
414    };
415
416    TokenStream::from(expanded)
417}
418
419/// Generate Python (PyO3) binding for a zero-argument constant function
420///
421/// # Arguments
422///
423/// * `name` - The constant name (e.g., `pi`)
424///
425/// # Examples
426///
427/// ```rust,ignore
428/// generate_python_constant_binding!(pi);
429/// ```
430///
431/// Generates:
432///
433/// ```rust,ignore
434/// #[pyfunction]
435/// pub fn pi() -> PyExpression {
436///     PyExpression {
437///         inner: Expression::pi(),
438///     }
439/// }
440/// ```
441#[proc_macro]
442pub fn generate_python_constant_binding(input: TokenStream) -> TokenStream {
443    let name = parse_macro_input!(input as Ident);
444
445    let expanded = quote! {
446        #[pyo3::pyfunction]
447        pub fn #name() -> crate::PyExpression {
448            use mathhook_core::Expression;
449            crate::PyExpression {
450                inner: Expression::#name(),
451            }
452        }
453    };
454
455    TokenStream::from(expanded)
456}
457
458/// Generate Node.js (NAPI) binding for a unary mathematical function
459///
460/// # Arguments
461///
462/// * `name` - The function name (e.g., `sin`)
463///
464/// # Examples
465///
466/// ```rust,ignore
467/// generate_nodejs_binding!(sin);
468/// ```
469///
470/// Generates:
471///
472/// ```rust,ignore
473/// #[napi]
474/// pub fn sin(x: napi::Either<&JsExpression, f64>) -> JsExpression {
475///     let expr = match x {
476///         napi::Either::A(e) => e.inner.clone(),
477///         napi::Either::B(num) => {
478///             if num.fract() == 0.0 && num.is_finite() {
479///                 Expression::integer(num as i64)
480///             } else {
481///                 Expression::float(num)
482///             }
483///         }
484///     };
485///     JsExpression {
486///         inner: Expression::function("sin", vec![expr]),
487///     }
488/// }
489/// ```
490#[proc_macro]
491pub fn generate_nodejs_binding(input: TokenStream) -> TokenStream {
492    let name = parse_macro_input!(input as Ident);
493    let name_str = name.to_string();
494
495    let expanded = quote! {
496        #[napi_derive::napi]
497        pub fn #name(x: napi::bindgen_prelude::Either<&crate::JsExpression, f64>) -> crate::JsExpression {
498            use mathhook_core::Expression;
499
500            let expr = match x {
501                napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
502                napi::bindgen_prelude::Either::B(num) => {
503                    if num.fract() == 0.0 && num.is_finite() {
504                        Expression::integer(num as i64)
505                    } else {
506                        Expression::float(num)
507                    }
508                }
509            };
510
511            crate::JsExpression {
512                inner: Expression::function(#name_str, vec![expr]),
513            }
514        }
515    };
516
517    TokenStream::from(expanded)
518}
519
520/// Generate Node.js (NAPI) binding for a binary mathematical function
521///
522/// # Arguments
523///
524/// * `name` - The function name (e.g., `pow`)
525///
526/// # Examples
527///
528/// ```rust,ignore
529/// generate_nodejs_binary_binding!(pow);
530/// ```
531///
532/// Generates:
533///
534/// ```rust,ignore
535/// #[napi]
536/// pub fn pow(
537///     x: napi::Either<&JsExpression, f64>,
538///     y: napi::Either<&JsExpression, f64>
539/// ) -> JsExpression {
540///     let expr1 = match x { /* ... */ };
541///     let expr2 = match y { /* ... */ };
542///     JsExpression {
543///         inner: Expression::function("pow", vec![expr1, expr2]),
544///     }
545/// }
546/// ```
547#[proc_macro]
548pub fn generate_nodejs_binary_binding(input: TokenStream) -> TokenStream {
549    let name = parse_macro_input!(input as Ident);
550    let name_str = name.to_string();
551
552    let expanded = quote! {
553        #[napi_derive::napi]
554        pub fn #name(
555            x: napi::bindgen_prelude::Either<&crate::JsExpression, f64>,
556            y: napi::bindgen_prelude::Either<&crate::JsExpression, f64>,
557        ) -> crate::JsExpression {
558            use mathhook_core::Expression;
559
560            let expr1 = match x {
561                napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
562                napi::bindgen_prelude::Either::B(num) => {
563                    if num.fract() == 0.0 && num.is_finite() {
564                        Expression::integer(num as i64)
565                    } else {
566                        Expression::float(num)
567                    }
568                }
569            };
570
571            let expr2 = match y {
572                napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
573                napi::bindgen_prelude::Either::B(num) => {
574                    if num.fract() == 0.0 && num.is_finite() {
575                        Expression::integer(num as i64)
576                    } else {
577                        Expression::float(num)
578                    }
579                }
580            };
581
582            crate::JsExpression {
583                inner: Expression::function(#name_str, vec![expr1, expr2]),
584            }
585        }
586    };
587
588    TokenStream::from(expanded)
589}
590
591/// Generate Node.js (NAPI) binding for a variadic mathematical function
592///
593/// # Arguments
594///
595/// * `name` - The function name (e.g., `add`)
596///
597/// # Examples
598///
599/// ```rust,ignore
600/// generate_nodejs_variadic_binding!(add);
601/// ```
602///
603/// Generates:
604///
605/// ```rust,ignore
606/// #[napi]
607/// pub fn add(args: Vec<napi::Either<&JsExpression, f64>>) -> JsExpression {
608///     let exprs: Vec<Expression> = args
609///         .into_iter()
610///         .map(|x| match x { /* ... */ })
611///         .collect();
612///     JsExpression {
613///         inner: Expression::function("add", exprs),
614///     }
615/// }
616/// ```
617#[proc_macro]
618pub fn generate_nodejs_variadic_binding(input: TokenStream) -> TokenStream {
619    let name = parse_macro_input!(input as Ident);
620    let name_str = name.to_string();
621
622    let expanded = quote! {
623        #[napi_derive::napi]
624        pub fn #name(args: Vec<napi::bindgen_prelude::Either<&crate::JsExpression, f64>>) -> crate::JsExpression {
625            use mathhook_core::Expression;
626
627            let exprs: Vec<Expression> = args
628                .into_iter()
629                .map(|x| match x {
630                    napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
631                    napi::bindgen_prelude::Either::B(num) => {
632                        if num.fract() == 0.0 && num.is_finite() {
633                            Expression::integer(num as i64)
634                        } else {
635                            Expression::float(num)
636                        }
637                    }
638                })
639                .collect();
640
641            crate::JsExpression {
642                inner: Expression::function(#name_str, exprs),
643            }
644        }
645    };
646
647    TokenStream::from(expanded)
648}
649
650/// Generate Node.js (NAPI) binding for a zero-argument constant function
651///
652/// # Arguments
653///
654/// * `name` - The constant name (e.g., `pi`)
655///
656/// # Examples
657///
658/// ```rust,ignore
659/// generate_nodejs_constant_binding!(pi);
660/// ```
661///
662/// Generates:
663///
664/// ```rust,ignore
665/// #[napi]
666/// pub fn pi() -> JsExpression {
667///     JsExpression {
668///         inner: Expression::pi(),
669///     }
670/// }
671/// ```
672#[proc_macro]
673pub fn generate_nodejs_constant_binding(input: TokenStream) -> TokenStream {
674    let name = parse_macro_input!(input as Ident);
675
676    let expanded = quote! {
677        #[napi_derive::napi]
678        pub fn #name() -> crate::JsExpression {
679            use mathhook_core::Expression;
680            crate::JsExpression {
681                inner: Expression::#name(),
682            }
683        }
684    };
685
686    TokenStream::from(expanded)
687}
688
#[cfg(test)]
mod tests {
    /// Smoke test: if this runs, every macro definition in the crate
    /// type-checked during the test build.
    #[test]
    fn test_macros_module_compiles() {
        let message = "All macro definitions compiled successfully";
        println!("{}", message);
    }
}