mathhook_macros/lib.rs
//! Code generation macros for MathHook language bindings and expression construction.
//!
//! This crate provides procedural macros that generate Python (PyO3) and Node.js (NAPI)
//! bindings from a single function name, eliminating per-language code duplication while
//! maintaining zero-cost abstractions.
//!
//! It also provides procedural macros for constructing mathematical expressions
//! (`expr!`), symbols (`symbol!`, `symbols!`), and function calls (`function!`).
//!
//! # Supported Binding Arities
//!
//! - **Unary**: sin(x), cos(x), exp(x)
//! - **Binary**: pow(x, y), atan2(y, x)
//! - **Variadic**: add(args...), mul(args...)
//! - **Constants**: pi(), e()
15
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::{parse_macro_input, Ident};
19
20mod expr;
21mod function;
22mod symbol;
23
24/// Procedural macro for creating mathematical expressions with full syntax support
25///
26/// # Supported Operators
27///
28/// - **Basic arithmetic**: `+`, `-`, `*`, `/`
29/// - **Unary negation**: `-x`
30/// - **Power operations**: Both `**` and `.pow()` syntax
31/// - **Comparison operators**: `==`, `<`, `>`, `<=`, `>=`
32///
33/// # Supported Literals
34///
35/// - **Integers**: `42`, `-5`, `0`
36/// - **Floats**: `3.14`, `-2.5`, `0.0`
37/// - **Identifiers/symbols**: `x`, `theta`, `alpha_1`
38/// - **Mathematical constants**: `pi`, `e`, `i`
39///
40/// # Supported Constructs
41///
42/// - **Function calls**: `sin(x)`, `log(x, y)`, `f(a, b, c)`
43/// - **Parenthesized expressions**: `(2*x + 3)`, `((x+y))`
44/// - **Method calls**: `x.pow(2)`, `x.abs()`, `x.sqrt()`, `x.simplify()`
45///
46/// # Power Operator Syntax
47///
48/// The macro supports **two syntaxes** for exponentiation:
49///
50/// ## Method Syntax: `.pow()`
51///
52/// The `.pow()` method works everywhere without restrictions:
53///
54/// ```rust,ignore
55/// expr!(x.pow(2)) // x squared
56/// expr!(x.pow(y)) // x to the y
57/// expr!((x + 1).pow(3)) // (x+1) cubed
58/// expr!(2 * x.pow(3) + 5) // Complex expression
59/// ```
60///
61/// **Advantages:**
62/// - Mirrors Rust's standard library API
63/// - Works seamlessly in any context
64/// - Clear, unambiguous precedence
65/// - No parentheses required
66///
67/// ## Infix Syntax: `**`
68///
69/// The `**` operator provides mathematical notation:
70///
71/// ```rust,ignore
72/// expr!(x ** 2) // x squared
73/// expr!(x ** y) // x to the y
74/// expr!((x + 1) ** 3) // (x+1) cubed
75/// ```
76///
77/// **IMPORTANT - When to use parentheses with `**`:**
78///
79/// Due to token-level preprocessing, complex expressions with `**` should use
80/// parentheses for clarity when mixing with other operators:
81///
82/// ```rust,ignore
83/// // ✅ RECOMMENDED (clear precedence):
84/// expr!(2 * (x ** 3) + 5) // Parentheses make precedence explicit
85/// expr!((x ** 2) + (y ** 2)) // Clear grouping
86///
87/// // ⚠️ WORKS BUT LESS CLEAR:
88/// expr!(x ** 2 + y ** 2) // Relies on implicit precedence
89/// ```
90///
91/// **Right-associativity:**
92///
93/// Power operations are right-associative (like standard mathematical notation):
94///
95/// ```rust,ignore
96/// expr!(x ** 2 ** 3) // Parsed as: x ** (2 ** 3) = x ** 8
97/// expr!(2 ** 3 ** 2) // Parsed as: 2 ** (3 ** 2) = 2 ** 9 = 512
98/// ```
99///
100/// **Which syntax to use?**
101///
102/// - Use `.pow()` for guaranteed clarity in complex expressions
103/// - Use `**` when writing simple mathematical formulas that match notation
104/// - When in doubt, add parentheses: `expr!(2 * (x ** 3) + 5)`
105///
106/// # Examples
107///
108/// ```rust,ignore
109/// use mathhook_macros::expr;
110/// use mathhook_core::{Expression, symbol};
111///
112/// let x = symbol!(x);
113/// let y = symbol!(y);
114///
115/// // Basic operations
116/// let sum = expr!(x + 2);
117/// let product = expr!(2 * x);
118///
119/// // Power operations (both syntaxes work)
120/// let power1 = expr!(x ** 2);
121/// let power2 = expr!(x.pow(2));
122///
123/// // Complex expressions (use parentheses with ** for clarity)
124/// let quadratic = expr!((x ** 2) + 2*x + 1);
125/// let nested = expr!((x + 1) * (x - 1));
126/// let functions = expr!(sin(x ** 2));
127///
128/// // Comparison operators
129/// let equation = expr!(x ** 2 == 4);
130/// let inequality = expr!(x + 1 > y);
131///
132/// // Method calls
133/// let abs_val = expr!(x.abs());
134/// let sqrt_expr = expr!(x.sqrt());
135/// let simplified = expr!((x + x).simplify());
136/// ```
137///
138/// # Precedence Rules
139///
140/// Operator precedence from highest to lowest:
141///
142/// 1. **Method calls** (highest): `.pow()`, `.abs()`, `.sqrt()`, `.simplify()`
143/// 2. **Power** (right-associative): `**`
144/// 3. **Unary negation**: `-x`
145/// 4. **Multiplication/division**: `*`, `/`
146/// 5. **Addition/subtraction**: `+`, `-`
147/// 6. **Comparison operators** (lowest): `==`, `<`, `>`, `<=`, `>=`
148///
149/// Use parentheses to override precedence:
150///
151/// ```rust,ignore
152/// expr!(2 * x + 3) // Parsed as: (2*x) + 3
153/// expr!((2 + 3) * x) // Parentheses override: 5*x
154/// expr!(2 * (x ** 3)) // Recommended for clarity with **
155/// expr!((x + 1) ** 2) // Power of sum
156/// ```
157#[proc_macro]
158pub fn expr(input: TokenStream) -> TokenStream {
159 expr::expr_impl(input)
160}
161
162/// Procedural macro for creating symbols with optional type specification
163///
164/// # Syntax
165///
166/// ```rust,ignore
167/// symbol!(x) // Scalar (default, commutative)
168/// symbol!(A; matrix) // Matrix (noncommutative)
169/// symbol!(p; operator) // Operator (noncommutative)
170/// symbol!(i; quaternion) // Quaternion (noncommutative)
171/// symbol!("name") // String literal for symbol name
172/// ```
173///
174/// # Symbol Types
175///
176/// - **scalar**: Commutative symbols (default) - variables like x, y, z
177/// - **matrix**: Noncommutative matrix symbols - A*B ≠ B*A
178/// - **operator**: Noncommutative operator symbols - for quantum mechanics [x,p] ≠ 0
179/// - **quaternion**: Noncommutative quaternion symbols - i*j = k, j*i = -k
180///
181/// # Examples
182///
183/// ```rust,ignore
184/// use mathhook_macros::symbol;
185///
186/// // Scalar symbols (commutative)
187/// let x = symbol!(x);
188/// let theta = symbol!(theta);
189///
190/// // Matrix symbols (noncommutative)
191/// let A = symbol!(A; matrix);
192/// let B = symbol!(B; matrix);
193///
194/// // Operator symbols (noncommutative)
195/// let p = symbol!(p; operator);
196/// let x_op = symbol!(x; operator);
197///
198/// // Quaternion symbols (noncommutative)
199/// let i = symbol!(i; quaternion);
200/// let j = symbol!(j; quaternion);
201/// ```
202#[proc_macro]
203pub fn symbol(input: TokenStream) -> TokenStream {
204 symbol::symbol_impl(input)
205}
206
207/// Procedural macro for creating multiple symbols at once
208///
209/// # Syntax
210///
211/// ```rust,ignore
212/// symbols![x, y, z] // All scalars (default)
213/// symbols![A, B, C => matrix] // All matrices
214/// symbols![p, x, H => operator] // All operators
215/// symbols![i, j, k => quaternion] // All quaternions
216/// ```
217///
218/// # Returns
219///
220/// Returns `Vec<Symbol>` containing all created symbols.
221///
222/// # Examples
223///
224/// ```rust,ignore
225/// use mathhook_macros::symbols;
226///
227/// // Scalar symbols (default, commutative)
228/// let syms = symbols![x, y, z];
229/// assert_eq!(syms.len(), 3);
230///
231/// // Matrix symbols (noncommutative)
232/// let mats = symbols![A, B, C => matrix];
233/// assert_eq!(mats.len(), 3);
234///
235/// // Operator symbols (noncommutative)
236/// let ops = symbols![p, x, H => operator];
237/// assert_eq!(ops.len(), 3);
238///
239/// // Quaternion symbols (noncommutative)
240/// let quats = symbols![i, j, k => quaternion];
241/// assert_eq!(quats.len(), 3);
242/// ```
243#[proc_macro]
244pub fn symbols(input: TokenStream) -> TokenStream {
245 symbol::symbols_impl(input)
246}
247
248/// Procedural macro for creating function expressions
249///
250/// # Syntax
251///
252/// ```rust,ignore
253/// function!(sin) // Zero args
254/// function!(sin, x) // One arg (x is an Expression)
255/// function!(log, x, y) // Two args
256/// function!(f, a, b, c) // N args
257/// ```
258///
259/// # Examples
260///
261/// ```rust,ignore
262/// use mathhook_macros::function;
263/// use mathhook_core::expr;
264///
265/// // Zero-argument function
266/// let gamma_call = function!(gamma);
267///
268/// // Single argument
269/// let x = expr!(x);
270/// let sin_x = function!(sin, x);
271///
272/// // Multiple arguments
273/// let log_xy = function!(log, x, expr!(2));
274/// ```
275#[proc_macro]
276pub fn function(input: TokenStream) -> TokenStream {
277 function::function_impl(input)
278}
279
280/// Generate Python (PyO3) binding for a unary mathematical function
281///
282/// # Arguments
283///
284/// * `name` - The function name (e.g., `sin`)
285///
286/// # Examples
287///
288/// ```rust,ignore
289/// generate_python_binding!(sin);
290/// ```
291///
292/// Generates:
293///
294/// ```rust,ignore
295/// #[pyfunction]
296/// pub fn sin(x: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
297/// let expr = sympify_python(x)?;
298/// Ok(PyExpression {
299/// inner: Expression::function("sin", vec![expr]),
300/// })
301/// }
302/// ```
303#[proc_macro]
304pub fn generate_python_binding(input: TokenStream) -> TokenStream {
305 let name = parse_macro_input!(input as Ident);
306 let name_str = name.to_string();
307
308 let expanded = quote! {
309 #[::pyo3::pyfunction]
310 pub fn #name(x: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<crate::PyExpression> {
311 use mathhook_core::Expression;
312 let expr = crate::helpers::sympify_python(x)?;
313 Ok(crate::PyExpression {
314 inner: Expression::function(#name_str, vec![expr]),
315 })
316 }
317 };
318
319 TokenStream::from(expanded)
320}
321
322/// Generate Python (PyO3) binding for a binary mathematical function
323///
324/// # Arguments
325///
326/// * `name` - The function name (e.g., `pow`)
327///
328/// # Examples
329///
330/// ```rust,ignore
331/// generate_python_binary_binding!(pow);
332/// ```
333///
334/// Generates:
335///
336/// ```rust,ignore
337/// #[pyfunction]
338/// pub fn pow(x: &Bound<'_, PyAny>, y: &Bound<'_, PyAny>) -> PyResult<PyExpression> {
339/// let expr1 = sympify_python(x)?;
340/// let expr2 = sympify_python(y)?;
341/// Ok(PyExpression {
342/// inner: Expression::function("pow", vec![expr1, expr2]),
343/// })
344/// }
345/// ```
346#[proc_macro]
347pub fn generate_python_binary_binding(input: TokenStream) -> TokenStream {
348 let name = parse_macro_input!(input as Ident);
349 let name_str = name.to_string();
350
351 let expanded = quote! {
352 #[::pyo3::pyfunction]
353 pub fn #name(
354 x: &pyo3::Bound<'_, pyo3::PyAny>,
355 y: &pyo3::Bound<'_, pyo3::PyAny>,
356 ) -> pyo3::PyResult<crate::PyExpression> {
357 use mathhook_core::Expression;
358 let expr1 = crate::helpers::sympify_python(x)?;
359 let expr2 = crate::helpers::sympify_python(y)?;
360 Ok(crate::PyExpression {
361 inner: Expression::function(#name_str, vec![expr1, expr2]),
362 })
363 }
364 };
365
366 TokenStream::from(expanded)
367}
368
369/// Generate Python (PyO3) binding for a variadic mathematical function
370///
371/// # Arguments
372///
373/// * `name` - The function name (e.g., `add`)
374///
375/// # Examples
376///
377/// ```rust,ignore
378/// generate_python_variadic_binding!(add);
379/// ```
380///
381/// Generates:
382///
383/// ```rust,ignore
384/// #[pyfunction]
385/// #[pyo3(signature = (*args))]
386/// pub fn add(args: &Bound<'_, PyTuple>) -> PyResult<PyExpression> {
387/// let mut exprs = Vec::new();
388/// for arg in args {
389/// exprs.push(sympify_python(&arg)?);
390/// }
391/// Ok(PyExpression {
392/// inner: Expression::function("add", exprs),
393/// })
394/// }
395/// ```
396#[proc_macro]
397pub fn generate_python_variadic_binding(input: TokenStream) -> TokenStream {
398 let name = parse_macro_input!(input as Ident);
399 let name_str = name.to_string();
400
401 let expanded = quote! {
402 #[::pyo3::pyfunction]
403 #[pyo3(signature = (*args))]
404 pub fn #name(args: &pyo3::Bound<'_, pyo3::types::PyTuple>) -> pyo3::PyResult<crate::PyExpression> {
405 use mathhook_core::Expression;
406 let mut exprs = Vec::new();
407 for arg in args {
408 exprs.push(crate::helpers::sympify_python(&arg)?);
409 }
410 Ok(crate::PyExpression {
411 inner: Expression::function(#name_str, exprs),
412 })
413 }
414 };
415
416 TokenStream::from(expanded)
417}
418
419/// Generate Python (PyO3) binding for a zero-argument constant function
420///
421/// # Arguments
422///
423/// * `name` - The constant name (e.g., `pi`)
424///
425/// # Examples
426///
427/// ```rust,ignore
428/// generate_python_constant_binding!(pi);
429/// ```
430///
431/// Generates:
432///
433/// ```rust,ignore
434/// #[pyfunction]
435/// pub fn pi() -> PyExpression {
436/// PyExpression {
437/// inner: Expression::pi(),
438/// }
439/// }
440/// ```
441#[proc_macro]
442pub fn generate_python_constant_binding(input: TokenStream) -> TokenStream {
443 let name = parse_macro_input!(input as Ident);
444
445 let expanded = quote! {
446 #[::pyo3::pyfunction]
447 pub fn #name() -> crate::PyExpression {
448 use mathhook_core::Expression;
449 crate::PyExpression {
450 inner: Expression::#name(),
451 }
452 }
453 };
454
455 TokenStream::from(expanded)
456}
457
458/// Generate Node.js (NAPI) binding for a unary mathematical function
459///
460/// # Arguments
461///
462/// * `name` - The function name (e.g., `sin`)
463///
464/// # Examples
465///
466/// ```rust,ignore
467/// generate_nodejs_binding!(sin);
468/// ```
469///
470/// Generates:
471///
472/// ```rust,ignore
473/// #[napi]
474/// pub fn sin(x: ExpressionOrNumber) -> JsExpression {
475/// JsExpression {
476/// inner: Expression::function("sin", vec![x.0]),
477/// }
478/// }
479/// ```
480///
481/// This accepts both Expression objects and numbers: `sin(x)` or `sin(1.5)`
482#[proc_macro]
483pub fn generate_nodejs_binding(input: TokenStream) -> TokenStream {
484 let name = parse_macro_input!(input as Ident);
485 let name_str = name.to_string();
486
487 let expanded = quote! {
488 #[napi_derive::napi]
489 pub fn #name(x: crate::functions::ExpressionOrNumber) -> crate::JsExpression {
490 use mathhook_core::Expression;
491
492 crate::JsExpression {
493 inner: Expression::function(#name_str, vec![x.0]),
494 }
495 }
496 };
497
498 TokenStream::from(expanded)
499}
500
501/// Generate Node.js (NAPI) binding for a binary mathematical function
502///
503/// # Arguments
504///
505/// * `name` - The function name (e.g., `pow`)
506///
507/// # Examples
508///
509/// ```rust,ignore
510/// generate_nodejs_binary_binding!(pow);
511/// ```
512///
513/// Generates:
514///
515/// ```rust,ignore
516/// #[napi]
517/// pub fn pow(x: ExpressionOrNumber, y: ExpressionOrNumber) -> JsExpression {
518/// JsExpression {
519/// inner: Expression::function("pow", vec![x.0, y.0]),
520/// }
521/// }
522/// ```
523///
524/// This accepts both Expression objects and numbers: `pow(x, 2)` or `pow(2, 3)`
525#[proc_macro]
526pub fn generate_nodejs_binary_binding(input: TokenStream) -> TokenStream {
527 let name = parse_macro_input!(input as Ident);
528 let name_str = name.to_string();
529
530 let expanded = quote! {
531 #[napi_derive::napi]
532 pub fn #name(
533 x: crate::functions::ExpressionOrNumber,
534 y: crate::functions::ExpressionOrNumber,
535 ) -> crate::JsExpression {
536 use mathhook_core::Expression;
537
538 crate::JsExpression {
539 inner: Expression::function(#name_str, vec![x.0, y.0]),
540 }
541 }
542 };
543
544 TokenStream::from(expanded)
545}
546
547/// Generate Node.js (NAPI) binding for a variadic mathematical function
548///
549/// # Arguments
550///
551/// * `name` - The function name (e.g., `add`)
552///
553/// # Examples
554///
555/// ```rust,ignore
556/// generate_nodejs_variadic_binding!(add);
557/// ```
558///
559/// Generates:
560///
561/// ```rust,ignore
562/// #[napi]
563/// pub fn add(args: Vec<napi::Either<&JsExpression, f64>>) -> JsExpression {
564/// let exprs: Vec<Expression> = args
565/// .into_iter()
566/// .map(|x| match x { /* ... */ })
567/// .collect();
568/// JsExpression {
569/// inner: Expression::function("add", exprs),
570/// }
571/// }
572/// ```
573#[proc_macro]
574pub fn generate_nodejs_variadic_binding(input: TokenStream) -> TokenStream {
575 let name = parse_macro_input!(input as Ident);
576 let name_str = name.to_string();
577
578 let expanded = quote! {
579 #[napi_derive::napi]
580 pub fn #name(args: Vec<napi::bindgen_prelude::Either<&crate::JsExpression, f64>>) -> crate::JsExpression {
581 use mathhook_core::Expression;
582
583 let exprs: Vec<Expression> = args
584 .into_iter()
585 .map(|x| match x {
586 napi::bindgen_prelude::Either::A(e) => e.inner.clone(),
587 napi::bindgen_prelude::Either::B(num) => {
588 if num.fract() == 0.0 && num.is_finite() {
589 Expression::integer(num as i64)
590 } else {
591 Expression::float(num)
592 }
593 }
594 })
595 .collect();
596
597 crate::JsExpression {
598 inner: Expression::function(#name_str, exprs),
599 }
600 }
601 };
602
603 TokenStream::from(expanded)
604}
605
606/// Generate Node.js (NAPI) binding for a zero-argument constant function
607///
608/// # Arguments
609///
610/// * `name` - The constant name (e.g., `pi`)
611///
612/// # Examples
613///
614/// ```rust,ignore
615/// generate_nodejs_constant_binding!(pi);
616/// ```
617///
618/// Generates:
619///
620/// ```rust,ignore
621/// #[napi]
622/// pub fn pi() -> JsExpression {
623/// JsExpression {
624/// inner: Expression::pi(),
625/// }
626/// }
627/// ```
628#[proc_macro]
629pub fn generate_nodejs_constant_binding(input: TokenStream) -> TokenStream {
630 let name = parse_macro_input!(input as Ident);
631
632 let expanded = quote! {
633 #[napi_derive::napi]
634 pub fn #name() -> crate::JsExpression {
635 use mathhook_core::Expression;
636 crate::JsExpression {
637 inner: Expression::#name(),
638 }
639 }
640 };
641
642 TokenStream::from(expanded)
643}
644
#[cfg(test)]
mod tests {
    /// Smoke test: if the test binary builds and runs, every macro definition in
    /// this crate compiled.
    #[test]
    fn test_macros_module_compiles() {
        let message = "All macro definitions compiled successfully";
        println!("{}", message);
    }
}