mathhook_macros/lib.rs
1//! Code generation macros for MathHook language bindings and expression construction
2//!
//! This crate provides procedural macros that generate Python (PyO3) and Node.js (NAPI)
//! bindings from a single function name, eliminating code duplication while
//! maintaining zero-cost abstractions.
//!
//! It also provides procedural macros for constructing mathematical expressions,
//! symbols, and function calls.
8//!
9//! # Supported Function Arities
10//!
11//! - **Unary**: sin(x), cos(x), exp(x)
12//! - **Binary**: pow(x, y), atan2(y, x)
13//! - **Variadic**: add(args...), mul(args...)
14//! - **Constants**: pi(), e()
15
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::{parse_macro_input, Ident};
19
20mod expr;
21mod function;
22mod symbol;
23
24/// Procedural macro for creating mathematical expressions with full syntax support
25///
26/// # Supported Operators
27///
28/// - **Basic arithmetic**: `+`, `-`, `*`, `/`
29/// - **Unary negation**: `-x`
30/// - **Power operations**: Both `**` and `.pow()` syntax
31/// - **Comparison operators**: `==`, `<`, `>`, `<=`, `>=`
32///
33/// # Supported Literals
34///
35/// - **Integers**: `42`, `-5`, `0`
36/// - **Floats**: `3.14`, `-2.5`, `0.0`
37/// - **Identifiers/symbols**: `x`, `theta`, `alpha_1`
38/// - **Mathematical constants**: `pi`, `e`, `i`
39///
40/// # Supported Constructs
41///
42/// - **Function calls**: `sin(x)`, `log(x, y)`, `f(a, b, c)`
43/// - **Parenthesized expressions**: `(2*x + 3)`, `((x+y))`
44/// - **Method calls**: `x.pow(2)`, `x.abs()`, `x.sqrt()`, `x.simplify()`
45///
46/// # Power Operator Syntax
47///
48/// The macro supports **two syntaxes** for exponentiation:
49///
50/// ## Method Syntax: `.pow()`
51///
52/// The `.pow()` method works everywhere without restrictions:
53///
54/// ```rust,ignore
55/// expr!(x.pow(2)) // x squared
56/// expr!(x.pow(y)) // x to the y
57/// expr!((x + 1).pow(3)) // (x+1) cubed
58/// expr!(2 * x.pow(3) + 5) // Complex expression
59/// ```
60///
61/// **Advantages:**
62/// - Mirrors Rust's standard library API
63/// - Works seamlessly in any context
64/// - Clear, unambiguous precedence
65/// - No parentheses required
66///
67/// ## Infix Syntax: `**`
68///
69/// The `**` operator provides mathematical notation:
70///
71/// ```rust,ignore
72/// expr!(x ** 2) // x squared
73/// expr!(x ** y) // x to the y
74/// expr!((x + 1) ** 3) // (x+1) cubed
75/// ```
76///
77/// **IMPORTANT - When to use parentheses with `**`:**
78///
79/// Due to token-level preprocessing, complex expressions with `**` should use
80/// parentheses for clarity when mixing with other operators:
81///
82/// ```rust,ignore
83/// // ✅ RECOMMENDED (clear precedence):
84/// expr!(2 * (x ** 3) + 5) // Parentheses make precedence explicit
85/// expr!((x ** 2) + (y ** 2)) // Clear grouping
86///
87/// // ⚠️ WORKS BUT LESS CLEAR:
88/// expr!(x ** 2 + y ** 2) // Relies on implicit precedence
89/// ```
90///
91/// **Right-associativity:**
92///
93/// Power operations are right-associative (like standard mathematical notation):
94///
95/// ```rust,ignore
96/// expr!(x ** 2 ** 3) // Parsed as: x ** (2 ** 3) = x ** 8
97/// expr!(2 ** 3 ** 2) // Parsed as: 2 ** (3 ** 2) = 2 ** 9 = 512
98/// ```
99///
100/// **Which syntax to use?**
101///
102/// - Use `.pow()` for guaranteed clarity in complex expressions
103/// - Use `**` when writing simple mathematical formulas that match notation
104/// - When in doubt, add parentheses: `expr!(2 * (x ** 3) + 5)`
105///
106/// # Examples
107///
108/// ```rust,ignore
109/// use mathhook_macros::expr;
110/// use mathhook_core::{Expression, symbol};
111///
112/// let x = symbol!(x);
113/// let y = symbol!(y);
114///
115/// // Basic operations
116/// let sum = expr!(x + 2);
117/// let product = expr!(2 * x);
118///
119/// // Power operations (both syntaxes work)
120/// let power1 = expr!(x ** 2);
121/// let power2 = expr!(x.pow(2));
122///
123/// // Complex expressions (use parentheses with ** for clarity)
124/// let quadratic = expr!((x ** 2) + 2*x + 1);
125/// let nested = expr!((x + 1) * (x - 1));
126/// let functions = expr!(sin(x ** 2));
127///
128/// // Comparison operators
129/// let equation = expr!(x ** 2 == 4);
130/// let inequality = expr!(x + 1 > y);
131///
132/// // Method calls
133/// let abs_val = expr!(x.abs());
134/// let sqrt_expr = expr!(x.sqrt());
135/// let simplified = expr!((x + x).simplify());
136/// ```
137///
138/// # Precedence Rules
139///
140/// Operator precedence from highest to lowest:
141///
142/// 1. **Method calls** (highest): `.pow()`, `.abs()`, `.sqrt()`, `.simplify()`
143/// 2. **Power** (right-associative): `**`
144/// 3. **Unary negation**: `-x`
145/// 4. **Multiplication/division**: `*`, `/`
146/// 5. **Addition/subtraction**: `+`, `-`
147/// 6. **Comparison operators** (lowest): `==`, `<`, `>`, `<=`, `>=`
148///
149/// Use parentheses to override precedence:
150///
151/// ```rust,ignore
152/// expr!(2 * x + 3) // Parsed as: (2*x) + 3
153/// expr!((2 + 3) * x) // Parentheses override: 5*x
154/// expr!(2 * (x ** 3)) // Recommended for clarity with **
155/// expr!((x + 1) ** 2) // Power of sum
156/// ```
157#[proc_macro]
158pub fn expr(input: TokenStream) -> TokenStream {
159 expr::expr_impl(input)
160}
161
162/// Procedural macro for creating symbols with optional type specification
163///
164/// # Syntax
165///
166/// ```rust,ignore
167/// symbol!(x) // Scalar (default, commutative)
168/// symbol!(A; matrix) // Matrix (noncommutative)
169/// symbol!(p; operator) // Operator (noncommutative)
170/// symbol!(i; quaternion) // Quaternion (noncommutative)
171/// symbol!("name") // String literal for symbol name
172/// ```
173///
174/// # Symbol Types
175///
176/// - **scalar**: Commutative symbols (default) - variables like x, y, z
177/// - **matrix**: Noncommutative matrix symbols - A*B ≠ B*A
178/// - **operator**: Noncommutative operator symbols - for quantum mechanics [x,p] ≠ 0
179/// - **quaternion**: Noncommutative quaternion symbols - i*j = k, j*i = -k
180///
181/// # Examples
182///
183/// ```rust,ignore
184/// use mathhook_macros::symbol;
185///
186/// // Scalar symbols (commutative)
187/// let x = symbol!(x);
188/// let theta = symbol!(theta);
189///
190/// // Matrix symbols (noncommutative)
191/// let A = symbol!(A; matrix);
192/// let B = symbol!(B; matrix);
193///
194/// // Operator symbols (noncommutative)
195/// let p = symbol!(p; operator);
196/// let x_op = symbol!(x; operator);
197///
198/// // Quaternion symbols (noncommutative)
199/// let i = symbol!(i; quaternion);
200/// let j = symbol!(j; quaternion);
201/// ```
202#[proc_macro]
203pub fn symbol(input: TokenStream) -> TokenStream {
204 symbol::symbol_impl(input)
205}
206
207/// Procedural macro for creating multiple symbols at once
208///
209/// # Syntax
210///
211/// ```rust,ignore
212/// symbols![x, y, z] // All scalars (default)
213/// symbols![A, B, C => matrix] // All matrices
214/// symbols![p, x, H => operator] // All operators
215/// symbols![i, j, k => quaternion] // All quaternions
216/// ```
217///
218/// # Returns
219///
220/// Returns `Vec<Symbol>` containing all created symbols.
221///
222/// # Examples
223///
224/// ```rust,ignore
225/// use mathhook_macros::symbols;
226///
227/// // Scalar symbols (default, commutative)
228/// let syms = symbols![x, y, z];
229/// assert_eq!(syms.len(), 3);
230///
231/// // Matrix symbols (noncommutative)
232/// let mats = symbols![A, B, C => matrix];
233/// assert_eq!(mats.len(), 3);
234///
235/// // Operator symbols (noncommutative)
236/// let ops = symbols![p, x, H => operator];
237/// assert_eq!(ops.len(), 3);
238///
239/// // Quaternion symbols (noncommutative)
240/// let quats = symbols![i, j, k => quaternion];
241/// assert_eq!(quats.len(), 3);
242/// ```
243#[proc_macro]
244pub fn symbols(input: TokenStream) -> TokenStream {
245 symbol::symbols_impl(input)
246}
247
248/// Procedural macro for creating function expressions
249///
250/// # Syntax
251///
252/// ```rust,ignore
253/// function!(sin) // Zero args
254/// function!(sin, x) // One arg (x is an Expression)
255/// function!(log, x, y) // Two args
256/// function!(f, a, b, c) // N args
257/// ```
258///
259/// # Examples
260///
261/// ```rust,ignore
262/// use mathhook_macros::function;
263/// use mathhook_core::expr;
264///
265/// // Zero-argument function
266/// let gamma_call = function!(gamma);
267///
268/// // Single argument
269/// let x = expr!(x);
270/// let sin_x = function!(sin, x);
271///
272/// // Multiple arguments
273/// let log_xy = function!(log, x, expr!(2));
274/// ```
275#[proc_macro]
276pub fn function(input: TokenStream) -> TokenStream {
277 function::function_impl(input)
278}
279
280/// Generate Python (PyO3) binding for a unary mathematical function
281///
282/// # Arguments
283///
284/// * `name` - The function name (e.g., `sin`)
285///
286/// # Examples
287///
288/// ```rust,ignore
289/// generate_python_binding!(sin);
290/// ```
291///
292/// Generates:
293///
294/// ```rust,ignore
295/// #[pyfunction]
296/// pub fn sin(x: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
297/// let expr = sympify_python(x)?;
298/// Ok(PyExpression {
299/// inner: Expression::function("sin", vec![expr]),
300/// })
301/// }
302/// ```
303#[proc_macro]
304pub fn generate_python_binding(input: TokenStream) -> TokenStream {
305 let name = parse_macro_input!(input as Ident);
306 let name_str = name.to_string();
307
308 let expanded = quote! {
309 #[pyo3::pyfunction]
310 pub fn #name(x: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<crate::PyExpression> {
311 use mathhook_core::Expression;
312 let expr = crate::helpers::sympify_python(x)?;
313 Ok(crate::PyExpression {
314 inner: Expression::function(#name_str, vec![expr]),
315 })
316 }
317 };
318
319 TokenStream::from(expanded)
320}
321
322/// Generate Python (PyO3) binding for a binary mathematical function
323///
324/// # Arguments
325///
326/// * `name` - The function name (e.g., `pow`)
327///
328/// # Examples
329///
330/// ```rust,ignore
331/// generate_python_binary_binding!(pow);
332/// ```
333///
334/// Generates:
335///
336/// ```rust,ignore
337/// #[pyfunction]
338/// pub fn pow(x: &Bound<'_, PyAny>, y: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
339/// let expr1 = sympify_python(x)?;
340/// let expr2 = sympify_python(y)?;
341/// Ok(PyExpression {
342/// inner: Expression::function("pow", vec![expr1, expr2]),
343/// })
344/// }
345/// ```
346#[proc_macro]
347pub fn generate_python_binary_binding(input: TokenStream) -> TokenStream {
348 let name = parse_macro_input!(input as Ident);
349 let name_str = name.to_string();
350
351 let expanded = quote! {
352 #[pyo3::pyfunction]
353 pub fn #name(
354 x: &pyo3::Bound<'_, pyo3::PyAny>,
355 y: &pyo3::Bound<'_, pyo3::PyAny>,
356 ) -> pyo3::PyResult<crate::PyExpression> {
357 use mathhook_core::Expression;
358 let expr1 = crate::helpers::sympify_python(x)?;
359 let expr2 = crate::helpers::sympify_python(y)?;
360 Ok(crate::PyExpression {
361 inner: Expression::function(#name_str, vec![expr1, expr2]),
362 })
363 }
364 };
365
366 TokenStream::from(expanded)
367}
368
369/// Generate Python (PyO3) binding for a variadic mathematical function
370///
371/// # Arguments
372///
373/// * `name` - The function name (e.g., `add`)
374///
375/// # Examples
376///
377/// ```rust,ignore
378/// generate_python_variadic_binding!(add);
379/// ```
380///
381/// Generates:
382///
383/// ```rust,ignore
384/// #[pyfunction]
385/// #[pyo3(signature = (*args))]
386/// pub fn add(args: &Bound<'_, PyTuple>) -> PyResult<PyExpression> {
387/// let mut exprs = Vec::new();
388/// for arg in args {
389/// exprs.push(sympify_python(&arg)?);
390/// }
391/// Ok(PyExpression {
392/// inner: Expression::function("add", exprs),
393/// })
394/// }
395/// ```
396#[proc_macro]
397pub fn generate_python_variadic_binding(input: TokenStream) -> TokenStream {
398 let name = parse_macro_input!(input as Ident);
399 let name_str = name.to_string();
400
401 let expanded = quote! {
402 #[pyo3::pyfunction]
403 #[pyo3(signature = (*args))]
404 pub fn #name(args: &pyo3::Bound<'_, pyo3::types::PyTuple>) -> pyo3::PyResult<crate::PyExpression> {
405 use mathhook_core::Expression;
406 let mut exprs = Vec::new();
407 for arg in args {
408 exprs.push(crate::helpers::sympify_python(&arg)?);
409 }
410 Ok(crate::PyExpression {
411 inner: Expression::function(#name_str, exprs),
412 })
413 }
414 };
415
416 TokenStream::from(expanded)
417}
418
419/// Generate Python (PyO3) binding for a zero-argument constant function
420///
421/// # Arguments
422///
423/// * `name` - The constant name (e.g., `pi`)
424///
425/// # Examples
426///
427/// ```rust,ignore
428/// generate_python_constant_binding!(pi);
429/// ```
430///
431/// Generates:
432///
433/// ```rust,ignore
434/// #[pyfunction]
435/// pub fn pi() -> PyExpression {
436/// PyExpression {
437/// inner: Expression::pi(),
438/// }
439/// }
440/// ```
441#[proc_macro]
442pub fn generate_python_constant_binding(input: TokenStream) -> TokenStream {
443 let name = parse_macro_input!(input as Ident);
444
445 let expanded = quote! {
446 #[pyo3::pyfunction]
447 pub fn #name() -> crate::PyExpression {
448 use mathhook_core::Expression;
449 crate::PyExpression {
450 inner: Expression::#name(),
451 }
452 }
453 };
454
455 TokenStream::from(expanded)
456}
457
458/// Generate Node.js (NAPI) binding for a unary mathematical function
459///
460/// # Arguments
461///
462/// * `name` - The function name (e.g., `sin`)
463///
464/// # Examples
465///
466/// ```rust,ignore
467/// generate_nodejs_binding!(sin);
468/// ```
469///
470/// Generates:
471///
472/// ```rust,ignore
473/// #[napi]
474/// pub fn sin(x: napi::Either<&JsExpression, f64>) -> JsExpression {
475/// let expr = match x {
476/// napi::Either::A(e) => e.inner.clone(),
477/// napi::Either::B(num) => {
478/// if num.fract() == 0.0 && num.is_finite() {
479/// Expression::integer(num as i64)
480/// } else {
481/// Expression::float(num)
482/// }
483/// }
484/// };
485/// JsExpression {
486/// inner: Expression::function("sin", vec![expr]),
487/// }
488/// }
489/// ```
490#[proc_macro]
491pub fn generate_nodejs_binding(input: TokenStream) -> TokenStream {
492 let name = parse_macro_input!(input as Ident);
493 let name_str = name.to_string();
494
495 let expanded = quote! {
496 #[napi_derive::napi]
497 pub fn #name(x: napi::bindgen_prelude::Either<&crate::JsExpression, f64>) -> crate::JsExpression {
498 use mathhook_core::Expression;
499
500 let expr = match x {
501 napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
502 napi::bindgen_prelude::Either::B(num) => {
503 if num.fract() == 0.0 && num.is_finite() {
504 Expression::integer(num as i64)
505 } else {
506 Expression::float(num)
507 }
508 }
509 };
510
511 crate::JsExpression {
512 inner: Expression::function(#name_str, vec![expr]),
513 }
514 }
515 };
516
517 TokenStream::from(expanded)
518}
519
520/// Generate Node.js (NAPI) binding for a binary mathematical function
521///
522/// # Arguments
523///
524/// * `name` - The function name (e.g., `pow`)
525///
526/// # Examples
527///
528/// ```rust,ignore
529/// generate_nodejs_binary_binding!(pow);
530/// ```
531///
532/// Generates:
533///
534/// ```rust,ignore
535/// #[napi]
536/// pub fn pow(
537/// x: napi::Either<&JsExpression, f64>,
538/// y: napi::Either<&JsExpression, f64>
539/// ) -> JsExpression {
540/// let expr1 = match x { /* ... */ };
541/// let expr2 = match y { /* ... */ };
542/// JsExpression {
543/// inner: Expression::function("pow", vec![expr1, expr2]),
544/// }
545/// }
546/// ```
547#[proc_macro]
548pub fn generate_nodejs_binary_binding(input: TokenStream) -> TokenStream {
549 let name = parse_macro_input!(input as Ident);
550 let name_str = name.to_string();
551
552 let expanded = quote! {
553 #[napi_derive::napi]
554 pub fn #name(
555 x: napi::bindgen_prelude::Either<&crate::JsExpression, f64>,
556 y: napi::bindgen_prelude::Either<&crate::JsExpression, f64>,
557 ) -> crate::JsExpression {
558 use mathhook_core::Expression;
559
560 let expr1 = match x {
561 napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
562 napi::bindgen_prelude::Either::B(num) => {
563 if num.fract() == 0.0 && num.is_finite() {
564 Expression::integer(num as i64)
565 } else {
566 Expression::float(num)
567 }
568 }
569 };
570
571 let expr2 = match y {
572 napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
573 napi::bindgen_prelude::Either::B(num) => {
574 if num.fract() == 0.0 && num.is_finite() {
575 Expression::integer(num as i64)
576 } else {
577 Expression::float(num)
578 }
579 }
580 };
581
582 crate::JsExpression {
583 inner: Expression::function(#name_str, vec![expr1, expr2]),
584 }
585 }
586 };
587
588 TokenStream::from(expanded)
589}
590
591/// Generate Node.js (NAPI) binding for a variadic mathematical function
592///
593/// # Arguments
594///
595/// * `name` - The function name (e.g., `add`)
596///
597/// # Examples
598///
599/// ```rust,ignore
600/// generate_nodejs_variadic_binding!(add);
601/// ```
602///
603/// Generates:
604///
605/// ```rust,ignore
606/// #[napi]
607/// pub fn add(args: Vec<napi::Either<&JsExpression, f64>>) -> JsExpression {
608/// let exprs: Vec<Expression> = args
609/// .into_iter()
610/// .map(|x| match x { /* ... */ })
611/// .collect();
612/// JsExpression {
613/// inner: Expression::function("add", exprs),
614/// }
615/// }
616/// ```
617#[proc_macro]
618pub fn generate_nodejs_variadic_binding(input: TokenStream) -> TokenStream {
619 let name = parse_macro_input!(input as Ident);
620 let name_str = name.to_string();
621
622 let expanded = quote! {
623 #[napi_derive::napi]
624 pub fn #name(args: Vec<napi::bindgen_prelude::Either<&crate::JsExpression, f64>>) -> crate::JsExpression {
625 use mathhook_core::Expression;
626
627 let exprs: Vec<Expression> = args
628 .into_iter()
629 .map(|x| match x {
630 napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
631 napi::bindgen_prelude::Either::B(num) => {
632 if num.fract() == 0.0 && num.is_finite() {
633 Expression::integer(num as i64)
634 } else {
635 Expression::float(num)
636 }
637 }
638 })
639 .collect();
640
641 crate::JsExpression {
642 inner: Expression::function(#name_str, exprs),
643 }
644 }
645 };
646
647 TokenStream::from(expanded)
648}
649
650/// Generate Node.js (NAPI) binding for a zero-argument constant function
651///
652/// # Arguments
653///
654/// * `name` - The constant name (e.g., `pi`)
655///
656/// # Examples
657///
658/// ```rust,ignore
659/// generate_nodejs_constant_binding!(pi);
660/// ```
661///
662/// Generates:
663///
664/// ```rust,ignore
665/// #[napi]
666/// pub fn pi() -> JsExpression {
667/// JsExpression {
668/// inner: Expression::pi(),
669/// }
670/// }
671/// ```
672#[proc_macro]
673pub fn generate_nodejs_constant_binding(input: TokenStream) -> TokenStream {
674 let name = parse_macro_input!(input as Ident);
675
676 let expanded = quote! {
677 #[napi_derive::napi]
678 pub fn #name() -> crate::JsExpression {
679 use mathhook_core::Expression;
680 crate::JsExpression {
681 inner: Expression::#name(),
682 }
683 }
684 };
685
686 TokenStream::from(expanded)
687}
688
#[cfg(test)]
mod tests {
    /// Smoke test: if this test binary runs at all, the crate (and every macro
    /// definition in it) compiled.
    #[test]
    fn test_macros_module_compiles() {
        let message = "All macro definitions compiled successfully";
        println!("{}", message);
    }
}