Skip to main content

arael_sym/
lib.rs

1//! Symbolic math library for expression trees, automatic differentiation,
2//! simplification, and code generation.
3//!
4//! `arael-sym` provides a lightweight computer algebra system built around a
5//! reference-counted expression tree ([`E`]).  Expressions are constructed from
6//! symbols and constants, combined with standard arithmetic operators (which
7//! auto-simplify), and then differentiated, evaluated, pretty-printed, or
8//! compiled to Rust source code.
9//!
10//! This crate is the symbolic engine behind the
11//! [`arael`](https://docs.rs/arael) optimization framework, where it powers
12//! compile-time constraint differentiation and code generation. It can also
13//! be used independently for any symbolic math task.
14//!
15//! # Scope and limitations
16//!
17//! `arael-sym` is focused on what's needed for nonlinear optimization:
18//! scalar expressions, differentiation, and code generation. Compared to
19//! a full CAS like Python's SymPy, it does **not** support:
20//!
21//! - Symbolic integration
22//! - Equation solving (solve for x)
23//! - Symbolic matrix algebra (symbolic determinant, inverse, eigenvalues)
24//! - Polynomial factoring, GCD, partial fractions
25//! - Limits, series expansion, Taylor series
26//! - Assumptions / domain reasoning (positive, real, integer)
27//! - Pattern matching / rewrite rules
28//! - Pretty-printing of intermediate simplification steps
29//!
30//! # Examples
31//!
32//! See [`docs/SYM.md`](https://github.com/harakas/arael/blob/master/docs/SYM.md)
33//! for the full reference with worked examples for every feature,
34//! and [`examples/sym_demo.rs`](https://github.com/harakas/arael/blob/master/examples/sym_demo.rs)
35//! for a runnable walkthrough (`cargo run --example sym_demo`).
36//! The tour below hits the high points.
37//!
38//! ## Basics
39//!
40//! The [`symbols!`] macro expands each bare identifier to
41//! `symbol("<name>")` and returns a tuple -- you write the name once
42//! instead of twice per variable. The [`sym!`] macro auto-inserts
43//! `.clone()` on every reused variable so the body reads as natural
44//! math without ownership boilerplate.
45//!
46//! Every expression has type [`E`], defined as
47//! `struct E(Rc<Expr>)`. Cloning is cheap (a reference-count bump) --
48//! the `.clone()` calls `sym!` inserts don't duplicate the
49//! expression tree.
50//!
51//! ```
52//! use arael_sym::*;
53//! let result = sym! {
54//!     let (x, y) = symbols!(x, y);
55//!     let f = x * y - 1.0 + pow(x, 2.0);
56//!     format!("{}", f)
57//! };
58//! assert_eq!(result, "x * y + x^2 - 1");
59//! ```
60//!
61//! ## Differentiation
62//!
63//! ```
64//! use arael_sym::*;
65//! let result = sym! {
66//!     let x = symbol("x");
67//!     let f = sin(x) * x;
68//!     // Product rule + chain rule applied automatically:
69//!     format!("{}", f.diff(x))
70//! };
71//! assert_eq!(result, "x * cos(x) + sin(x)");
72//! ```
73//!
74//! ## Evaluation
75//!
76//! ```
77//! use arael_sym::*;
78//! let val = sym! {
79//!     let x = symbol("x");
80//!     let f = x * x + 1.0;
81//!     let vars = std::collections::HashMap::from([("x", 3.0)]);
82//!     f.eval(&vars).unwrap()
83//! };
84//! assert_eq!(val, 10.0);
85//! ```
86//!
87//! ## Code generation
88//!
89//! ```
90//! use arael_sym::*;
91//! let (code1, code2) = sym! {
92//!     let (x, y) = symbols!(x, y);
93//!     let f = sin(x) + 1.0;
94//!     let g = atan2(y, x);
95//!     (f.to_rust("f64"), g.to_rust("f32"))
96//! };
97//! assert_eq!(code1, "x.sin() + 1.0_f64");
98//! assert_eq!(code2, "y.atan2(x)");
99//! ```
100//!
101//! ## Common Subexpression Elimination (CSE)
102//!
103//! ```
104//! use arael_sym::*;
105//! sym! {
106//!     let x = symbol("x");
107//!     let shared = sin(x) * cos(x);
108//!     let e1 = shared + 1.0;
109//!     let e2 = shared * 2.0;
110//!     let (intermediates, simplified) = cse(&[e1, e2]);
111//!     for (name, val) in &intermediates {
112//!         println!("let {} = {};", name, val);
113//!     }
114//!     for s in &simplified {
115//!         println!("{}", s);
116//!     }
117//! };
118//! // Output:
119//! //   let __x0 = cos(x) * sin(x);
120//! //   __x0 + 1
121//! //   2 * __x0
122//! ```
123//!
124//! ## Vectors and Matrices
125//!
126//! ```
127//! use arael_sym::*;
128//! let dot = sym! {
129//!     let (x, y, z) = symbols!(x, y, z);
130//!     let v = SymVec::new([x, y, z]);
131//!     let w = SymVec::new([1.0, 2.0, 3.0]);
132//!     format!("{}", v.dot(&w))
133//! };
134//! assert_eq!(dot, "x + 2 * y + 3 * z");
135//! ```
136//!
137//! ## Jacobian
138//!
139//! ```
140//! use arael_sym::*;
141//! let (j00, j01, j10, j11) = sym! {
142//!     let (x, y) = symbols!(x, y);
143//!     let f = vec![x * y, sin(x) + y];
144//!     let j = jacobian(&f, &["x", "y"]);
145//!     // j is 2x2: [[df0/dx, df0/dy], [df1/dx, df1/dy]]
146//!     (format!("{}", j.get(0, 0)),
147//!      format!("{}", j.get(0, 1)),
148//!      format!("{}", j.get(1, 0)),
149//!      format!("{}", j.get(1, 1)))
150//! };
151//! assert_eq!(j00, "y");      // d(x*y)/dx
152//! assert_eq!(j01, "x");      // d(x*y)/dy
153//! assert_eq!(j10, "cos(x)"); // d(sin(x)+y)/dx
154//! assert_eq!(j11, "1");      // d(sin(x)+y)/dy
155//! ```
156//!
157//!
158//! ## Parsing
159//!
160//! ```
161//! use arael_sym::*;
162//! let f = parse("sin(x)^2 + cos(x)^2").unwrap();
163//! assert_eq!(format!("{}", f), "sin(x)^2 + cos(x)^2");
164//!
165//! let vars = std::collections::HashMap::from([("x", 1.0)]);
166//! assert!((f.eval(&vars).unwrap() - 1.0).abs() < 1e-10);
167//! ```
168//!
169//! ## Named constants
170//!
171//! Named constants survive simplification (unlike numeric `Const` which may
172//! be folded away). Built-in: [`pi`], [`epsilon`], [`euler`]. Custom
173//! constants via [`named_const`]. The [`sym!`] macro accepts `pi` and
174//! `epsilon` as bare identifiers.
175//!
176//! ```
177//! use arael_sym::*;
178//! sym! {
179//!     let x = symbol("x");
180//!     let f = x * x + epsilon;           // bare identifier, no parens needed
181//!     assert_eq!(format!("{}", f), "x^2 + epsilon");
182//!     assert_eq!(format!("{}", sin(pi).simplify()), "0");
183//!     assert_eq!(format!("{}", cos(pi).simplify()), "-1");
184//!     assert_eq!(format!("{}", ln(euler()).simplify()), "1");
185//! };
186//! ```
187//!
188//! ## Identity and evaluation order
189//!
//! The simplifier flattens and reorders additive terms, which can cause
//! floating-point cancellation in generated code. For example,
//! `1 - x^2 + epsilon^2` might be reordered to `-x^2 + epsilon^2 + 1`.
//! At `x=1` that evaluates as `(-1 + epsilon^2) + 1`: the tiny `epsilon^2`
//! is rounded away when added to `-1`, so the result is exactly `0` instead
//! of `epsilon^2`.
195//!
//! The [`identity`] function acts as a barrier. `identity(expr)` evaluates
//! to `expr`, and because the identity's derivative with respect to its
//! argument is `1`, differentiating it yields just the derivative of `expr`.
//! The simplifier, however, cannot reorder terms across it, and codegen wraps
//! its body in parentheses to preserve evaluation order in the generated Rust
//! code.
200//!
201//! ```
202//! use arael_sym::*;
203//! sym! {
204//!     let x = symbol("x");
205//!     // Without identity: terms may reorder, epsilon^2 lost at x=1
206//!     // With identity: (1 - x^2) evaluates first, then epsilon^2 is added
207//!     let safe = identity(1.0 - x * x) + epsilon * epsilon;
208//!     let code = safe.to_rust("f64");
209//!     // Body is wrapped in parens in generated code
210//!     assert!(code.contains("(-x.powf(2.0_f64) + 1.0_f64)"));
211//! };
212//! ```
213//!
214//! This pattern is used internally by [`safe_asin`] and [`safe_acos`] to
215//! keep `epsilon^2` from being lost to floating-point cancellation in the
216//! derivative `1/sqrt(1 - x^2 + epsilon^2)`.
217//!
218//! ## Custom functions
219//!
220//! Define reusable symbolic functions with automatic differentiation.
221//! The factory functions return closures that can be called like regular
222//! functions.
223//!
224//! ```
225//! use arael_sym::*;
226//! sym! {
227//!     let x = symbol("x");
228//!     let square = simple_func1("square", |t| t * t);
229//!     let f = square(x + 1.0);
230//!     assert_eq!(format!("{}", f), "square(x + 1)");
231//!     assert_eq!(format!("{}", f.diff(x)), "2 * (x + 1)");
232//!     // Codegen inlines the expanded body:
233//!     assert_eq!(f.to_rust("f64"), "(x + 1.0_f64).powf(2.0_f64)");
234//! };
235//! ```
236//!
237//! ## Extern functions
238//!
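//! When a function's runtime behavior should differ from the expression used
//! to differentiate it (e.g. angle normalization, whose wrap-around jumps the
//! derivative ignores), use extern functions. In codegen they emit a call to a
//! named Rust function and are evaluated through a native Rust function. Their
//! derivatives come from a separate symbolic gradient.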
242//!
243//! ```
244//! use arael_sym::*;
245//! fn my_angle_diff(args: &[f64]) -> f64 {
246//!     let d = args[0] - args[1];
247//!     d - (2.0 * std::f64::consts::PI)
248//!       * (d / (2.0 * std::f64::consts::PI) + 0.5).floor()
249//! }
250//! sym! {
251//!     // codegen emits my_mod::angle_diff(a, b)
252//!     // differentiation uses gradient of (a - b)
253//!     // eval uses my_angle_diff
254//!     let angle_diff = extern_func2("angle_diff", "my_mod::angle_diff",
255//!         grad2(|a, b| a - b), my_angle_diff);
256//!     let (x, y) = symbols!(x, y);
257//!     let f = angle_diff(x * x, y);
258//!     assert_eq!(format!("{}", f.diff(x)), "2 * x");
259//!     assert_eq!(f.to_rust("f64"), "my_mod::angle_diff(x.powf(2.0_f64), y)");
260//!     // eval uses the native eval_fn:
261//!     let vars = std::collections::HashMap::from([("x", 0.0), ("y", 6.283185307179586)]);
262//!     assert!(f.eval(&vars).unwrap().abs() < 1e-10); // 0 - 2pi wraps to 0
263//! };
264//! ```
265//!
266//! Built-in [`rad_diff`] and [`rad_sum`] are extern functions with
267//! rollover-safe angle normalization to \[-pi, pi\].
268//!
269//! ## Heaviside and clamp
270//!
271//! Pragmatic functions for optimization near numerical boundaries.
272//! `heaviside` has derivative 0 everywhere (not Dirac delta).
273//! `clamp` has pass-through derivative (as if clamping were not there).
274//!
275//! ```
276//! use arael_sym::*;
277//! sym! {
278//!     // clamp prevents NaN from asin outside [-1, 1]
279//!     // Note: derivative still diverges at +/-1. One can prevent it
280//!     // by providing custom derivatives with simple_func1_derivs as
281//!     // is done in the built-in safe_asin().
282//!     let my_asin = simple_func1("my_asin",
283//!         |t| asin(clamp(t, -1.0, 1.0)));
284//!     let x = symbol("x");
285//!     let f = my_asin(x);
286//!     let vars = std::collections::HashMap::from([("x", 1.5)]);
287//!     // Clamped to asin(1.0) = pi/2, no NaN
288//!     let val = f.eval(&vars).unwrap();
289//!     assert!((val - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
290//! };
291//! ```
292
293#![allow(clippy::should_implement_trait)]
294
295mod diff;
296mod eval;
297mod fmt;
298mod simplify;
299mod linalg;
300mod parse;
301pub mod geo;
302pub mod cse;
303
304use std::hash::{Hash, Hasher};
305use std::rc::Rc;
306
307/// Symbolic expression wrapper.
308///
309/// Reference-counted (cheap to clone).  All arithmetic operations auto-simplify.
310/// Dereferences to [`Expr`] so all methods on `Expr` (e.g. [`Expr::diff`],
311/// [`Expr::eval`], [`Expr::simplify`]) are available directly on `E`.
312#[derive(Clone, PartialEq)]
313pub struct E(Rc<Expr>);
314
315impl Eq for E {}
316
317impl E {
318    fn new(expr: Expr) -> E {
319        E(Rc::new(expr))
320    }
321
322    /// Collect all symbol names referenced in this expression.
323    pub fn symbols(&self) -> std::collections::HashSet<String> {
324        let mut out = std::collections::HashSet::new();
325        self.collect_symbols(&mut out);
326        out
327    }
328
329    fn collect_symbols(&self, out: &mut std::collections::HashSet<String>) {
330        match &*self.0 {
331            Expr::Sym(s) => { out.insert(s.clone()); }
332            Expr::Const(_) | Expr::NamedConst { .. } => {}
333            Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
334            | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
335            | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
336            | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
337            | Expr::Sqrt(a) | Expr::Abs(a)
338            | Expr::Heaviside(a) => { a.collect_symbols(out); }
339            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
340            | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
341                a.collect_symbols(out);
342                b.collect_symbols(out);
343            }
344            Expr::Clamp(a, b, c) => {
345                a.collect_symbols(out);
346                b.collect_symbols(out);
347                c.collect_symbols(out);
348            }
349            Expr::Func { args, .. } => {
350                for arg in args { arg.collect_symbols(out); }
351            }
352        }
353    }
354
355    /// Substitute symbols in this expression. Each pair `(from, to)` replaces
356    /// occurrences of `from` with `to`. Returns a new expression.
357    pub fn substitute(&self, subs: &[(E, E)]) -> E {
358        for (from, to) in subs {
359            if self == from { return to.clone(); }
360        }
361        match &*self.0 {
362            Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => self.clone(),
363            Expr::Neg(a) => -a.substitute(subs),
364            Expr::Add(a, b) => a.substitute(subs) + b.substitute(subs),
365            Expr::Sub(a, b) => a.substitute(subs) - b.substitute(subs),
366            Expr::Mul(a, b) => a.substitute(subs) * b.substitute(subs),
367            Expr::Div(a, b) => a.substitute(subs) / b.substitute(subs),
368            Expr::Pow(a, b) => pow(a.substitute(subs), b.substitute(subs)),
369            Expr::Sin(a) => sin(a.substitute(subs)),
370            Expr::Cos(a) => cos(a.substitute(subs)),
371            Expr::Tan(a) => tan(a.substitute(subs)),
372            Expr::Asin(a) => asin(a.substitute(subs)),
373            Expr::Acos(a) => acos(a.substitute(subs)),
374            Expr::Atan(a) => atan(a.substitute(subs)),
375            Expr::Atan2(a, b) => atan2(a.substitute(subs), b.substitute(subs)),
376            Expr::Sinh(a) => sinh(a.substitute(subs)),
377            Expr::Cosh(a) => cosh(a.substitute(subs)),
378            Expr::Tanh(a) => tanh(a.substitute(subs)),
379            Expr::Exp(a) => exp(a.substitute(subs)),
380            Expr::Ln(a) => ln(a.substitute(subs)),
381            Expr::Log2(a) => log2(a.substitute(subs)),
382            Expr::Log10(a) => ln(a.substitute(subs)) / ln(constant(10.0)),
383            Expr::Sqrt(a) => sqrt(a.substitute(subs)),
384            Expr::Abs(a) => abs(a.substitute(subs)),
385            Expr::Heaviside(a) => heaviside(a.substitute(subs)),
386            Expr::Clamp(a, lo, hi) => clamp(a.substitute(subs), lo.substitute(subs), hi.substitute(subs)),
387            Expr::Func { name, params, kind, args } => {
388                let new_args = args.iter().map(|a| a.substitute(subs)).collect();
389                E::new(Expr::Func { name: name.clone(), params: params.clone(), kind: kind.clone(), args: new_args })
390            }
391        }
392    }
393}
394
395impl std::ops::Deref for E {
396    type Target = Expr;
397    fn deref(&self) -> &Expr {
398        &self.0
399    }
400}
401
402impl AsRef<Expr> for E {
403    fn as_ref(&self) -> &Expr {
404        &self.0
405    }
406}
407
408/// Expression AST node.
409///
410/// Normally constructed via [`symbol`], [`constant`], and the free-standing
411/// math functions (e.g. [`sin`], [`cos`], [`pow`]) rather than directly.
412#[derive(Debug, Clone, PartialEq)]
413pub enum Expr {
414    /// Named symbolic variable.
415    Sym(String),
416    /// Numeric constant.
417    Const(f64),
418    /// Unary negation.
419    Neg(E),
420    /// Addition.
421    Add(E, E),
422    /// Subtraction.
423    Sub(E, E),
424    /// Multiplication.
425    Mul(E, E),
426    /// Division.
427    Div(E, E),
428    /// Exponentiation (base^exponent).
429    Pow(E, E),
430    /// Sine.
431    Sin(E),
432    /// Cosine.
433    Cos(E),
434    /// Tangent.
435    Tan(E),
436    /// Arcsine.
437    Asin(E),
438    /// Arccosine.
439    Acos(E),
440    /// Arctangent.
441    Atan(E),
442    /// Two-argument arctangent (atan2(y, x)).
443    Atan2(E, E),
444    /// Hyperbolic sine.
445    Sinh(E),
446    /// Hyperbolic cosine.
447    Cosh(E),
448    /// Hyperbolic tangent.
449    Tanh(E),
450    /// Exponential (e^x).
451    Exp(E),
452    /// Natural logarithm.
453    Ln(E),
454    /// Base-2 logarithm.
455    Log2(E),
456    /// Base-10 logarithm.
457    Log10(E),
458    /// Square root.
459    Sqrt(E),
460    /// Absolute value.
461    Abs(E),
462    /// Heaviside step function: 0 if x < 0, 1 if x >= 0. Derivative is 0.
463    Heaviside(E),
464    /// Clamp value to [lo, hi]. Derivative passes through (= d(val)/dvar).
465    Clamp(E, E, E),
466    /// Named constant (pi, epsilon, e, or user-defined).
467    /// Survives simplification (unlike Const which may be folded away).
468    NamedConst {
469        name: String,
470        value: f64,
471        rust_f32: String,
472        rust_f64: String,
473        latex: String,
474    },
475    /// User-defined function application.
476    Func {
477        /// Function name (for display).
478        name: String,
479        /// Formal parameter names.
480        params: Vec<String>,
481        /// Function behavior (differentiation, codegen, eval).
482        kind: FuncKind,
483        /// Actual argument expressions.
484        args: Vec<E>,
485    },
486}
487
488/// Describes what kind of function behavior to use for differentiation,
489/// evaluation, and code generation.
490#[derive(Debug, Clone, PartialEq)]
491#[allow(unpredictable_function_pointer_comparisons)]
492pub enum FuncKind {
493    /// Body auto-differentiated. Body used for eval and codegen (inlined).
494    Symbolic { body: E },
495    /// Explicit per-argument derivatives. Body used for eval and codegen (inlined).
496    SymbolicDerivs { body: E, derivs: Vec<E> },
497    /// Explicit per-argument derivatives. Codegen emits `call_path(args...)`.
498    /// `eval_fn` used for eval (required).
499    Extern { derivs: Vec<E>, eval_fn: fn(&[f64]) -> f64, call_path: String },
500}
501
502impl FuncKind {
503    /// Body for auto-differentiation (Symbolic only).
504    pub fn auto_diff_body(&self) -> Option<&E> {
505        match self {
506            FuncKind::Symbolic { body } => Some(body),
507            _ => None,
508        }
509    }
510
511    /// Explicit per-argument derivatives (SymbolicDerivs and Extern).
512    pub fn derivs(&self) -> Option<&[E]> {
513        match self {
514            FuncKind::SymbolicDerivs { derivs, .. } | FuncKind::Extern { derivs, .. } => Some(derivs),
515            FuncKind::Symbolic { .. } => None,
516        }
517    }
518
519    /// Body for symbolic eval and codegen inlining (Symbolic variants).
520    pub fn body(&self) -> Option<&E> {
521        match self {
522            FuncKind::Symbolic { body } | FuncKind::SymbolicDerivs { body, .. } => Some(body),
523            FuncKind::Extern { .. } => None,
524        }
525    }
526
527    /// Native eval function (Extern only).
528    pub fn eval_fn(&self) -> Option<fn(&[f64]) -> f64> {
529        match self {
530            FuncKind::Extern { eval_fn, .. } => Some(*eval_fn),
531            _ => None,
532        }
533    }
534}
535
536impl Hash for FuncKind {
537    fn hash<H: Hasher>(&self, state: &mut H) {
538        std::mem::discriminant(self).hash(state);
539        match self {
540            FuncKind::Symbolic { body } => body.hash(state),
541            FuncKind::SymbolicDerivs { body, derivs } => {
542                body.hash(state);
543                derivs.hash(state);
544            }
545            FuncKind::Extern { derivs, eval_fn, call_path } => {
546                derivs.hash(state);
547                (*eval_fn as usize).hash(state);
548                call_path.hash(state);
549            }
550        }
551    }
552}
553
554impl Eq for Expr {}
555
556impl Hash for Expr {
557    fn hash<H: Hasher>(&self, state: &mut H) {
558        std::mem::discriminant(self).hash(state);
559        match self {
560            Expr::Sym(s) => s.hash(state),
561            Expr::Const(v) => v.to_bits().hash(state),
562            Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
563            | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
564            | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
565            | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
566            | Expr::Sqrt(a) | Expr::Abs(a)
567            | Expr::Heaviside(a) => a.hash(state),
568            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
569            | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
570                a.hash(state);
571                b.hash(state);
572            }
573            Expr::Clamp(a, b, c) => {
574                a.hash(state);
575                b.hash(state);
576                c.hash(state);
577            }
578            Expr::NamedConst { name, value, .. } => {
579                name.hash(state);
580                value.to_bits().hash(state);
581            }
582            Expr::Func { name, params, kind, args } => {
583                name.hash(state);
584                params.hash(state);
585                kind.hash(state);
586                args.hash(state);
587            }
588        }
589    }
590}
591
592impl Hash for E {
593    fn hash<H: Hasher>(&self, state: &mut H) {
594        self.0.hash(state);
595    }
596}
597
598// --- Constructors ---
599
600/// Create a named symbolic variable.
601pub fn symbol(name: &str) -> E {
602    E::new(Expr::Sym(name.to_string()))
603}
604
605/// Types that can name a symbolic variable for operations that key
606/// into the expression tree by name -- `diff`, `subs`, `collect`.
607/// Implemented for `&str`, `String`, `&String`, and [`E`] (when it
608/// wraps a `Sym` node), so you can write `expr.diff("x")` or
609/// `expr.diff(&my_symbol)` and reach the same variable. The blanket
610/// `var_expr` default builds a fresh `Sym` node from the name;
611/// implementations on [`E`] override it to reuse the caller's handle
612/// and avoid an allocation.
613pub trait AsVarName {
614    /// Return the variable name as a string slice.
615    fn var_name(&self) -> &str;
616
617    /// Return an `E` representing this variable. Default: build a
618    /// fresh `Sym` node from `var_name()`.
619    fn var_expr(&self) -> E {
620        symbol(self.var_name())
621    }
622}
623
624impl AsVarName for &str {
625    fn var_name(&self) -> &str { self }
626}
627
628impl AsVarName for &&str {
629    fn var_name(&self) -> &str { self }
630}
631
632impl AsVarName for str {
633    fn var_name(&self) -> &str { self }
634}
635
636impl AsVarName for String {
637    fn var_name(&self) -> &str { self.as_str() }
638}
639
640impl AsVarName for &String {
641    fn var_name(&self) -> &str { self.as_str() }
642}
643
644impl AsVarName for &E {
645    fn var_name(&self) -> &str { (*self).var_name() }
646    fn var_expr(&self) -> E { (*self).clone() }
647}
648
649impl AsVarName for E {
650    fn var_name(&self) -> &str {
651        match self.as_ref() {
652            Expr::Sym(name) => name.as_str(),
653            _ => panic!("AsVarName::var_name: expected a symbol, got `{self}`"),
654        }
655    }
656    fn var_expr(&self) -> E { self.clone() }
657}
658
659/// Create several symbolic variables at once and return them as a
660/// tuple. Each identifier becomes a fresh [`E`] whose name is that
661/// identifier stringified, sparing the caller from writing the name
662/// twice per variable.
663///
664/// ```
665/// use arael_sym::*;
666/// let (x, y, z) = symbols!(x, y, z);
667/// assert_eq!(format!("{}", x * y + z), "x * y + z");
668/// ```
669///
670/// A trailing comma in the expansion makes the single-identifier
671/// form a 1-tuple (`(E,)`); for a single symbol [`symbol`] is
672/// usually the clearer spelling.
673#[macro_export]
674macro_rules! symbols {
675    ($($name:ident),+ $(,)?) => {
676        ( $( $crate::symbol(stringify!($name)) ),+ , )
677    };
678}
679
680/// Create a numeric constant.
681pub fn constant(val: f64) -> E {
682    E::new(Expr::Const(val))
683}
684
685impl From<f64> for E {
686    fn from(v: f64) -> E { constant(v) }
687}
688
689impl From<i64> for E {
690    fn from(v: i64) -> E { constant(v as f64) }
691}
692
693impl From<i32> for E {
694    fn from(v: i32) -> E { constant(v as f64) }
695}
696
697/// Create a named constant with explicit display, eval, codegen, and LaTeX representations.
698pub fn named_const(name: &str, value: f64, rust_f32: &str, rust_f64: &str, latex: &str) -> E {
699    E::new(Expr::NamedConst {
700        name: name.to_string(), value,
701        rust_f32: rust_f32.to_string(), rust_f64: rust_f64.to_string(),
702        latex: latex.to_string(),
703    })
704}
705
706/// $\pi = 3.14159\ldots$
707pub fn pi() -> E {
708    named_const("pi", std::f64::consts::PI,
709        "std::f32::consts::PI", "std::f64::consts::PI", "\\pi")
710}
711
712/// Machine epsilon $\epsilon$ (`f64::EPSILON` $\approx 2.22 \times 10^{-16}$).
713pub fn epsilon() -> E {
714    named_const("epsilon", f64::EPSILON,
715        "f32::EPSILON", "f64::EPSILON", "\\epsilon")
716}
717
718/// Euler's number $e = 2.71828\ldots$
719pub fn euler() -> E {
720    named_const("e", std::f64::consts::E,
721        "std::f32::consts::E", "std::f64::consts::E", "e")
722}
723
724/// Short alias for [`constant`]. Common in math notation.
725pub fn c(val: f64) -> E { constant(val) }
726
727/// Symbolic sine function.
728pub fn sin(e: E) -> E { E::new(Expr::Sin(e)) }
729/// Symbolic cosine function.
730pub fn cos(e: E) -> E { E::new(Expr::Cos(e)) }
731/// Symbolic tangent function.
732pub fn tan(e: E) -> E { E::new(Expr::Tan(e)) }
733/// Symbolic arcsine function.
734pub fn asin(e: E) -> E { E::new(Expr::Asin(e)) }
735/// Symbolic arccosine function.
736pub fn acos(e: E) -> E { E::new(Expr::Acos(e)) }
737/// Symbolic arctangent function.
738pub fn atan(e: E) -> E { E::new(Expr::Atan(e)) }
739/// Symbolic two-argument arctangent: atan2(y, x).
740pub fn atan2(y: E, x: E) -> E { E::new(Expr::Atan2(y, x)) }
741/// Symbolic hyperbolic sine function.
742pub fn sinh(e: E) -> E { E::new(Expr::Sinh(e)) }
743/// Symbolic hyperbolic cosine function.
744pub fn cosh(e: E) -> E { E::new(Expr::Cosh(e)) }
745/// Symbolic hyperbolic tangent function.
746pub fn tanh(e: E) -> E { E::new(Expr::Tanh(e)) }
747/// Symbolic exponential function (e^x).
748pub fn exp(e: E) -> E { E::new(Expr::Exp(e)) }
749/// Symbolic natural logarithm.
750pub fn ln(e: E) -> E { E::new(Expr::Ln(e)) }
751/// Symbolic base-2 logarithm.
752pub fn log2(e: E) -> E { E::new(Expr::Log2(e)) }
753/// Symbolic base-10 logarithm.
754pub fn log10(e: E) -> E { E::new(Expr::Log10(e)) }
755/// Symbolic square root.
756pub fn sqrt(e: E) -> E { E::new(Expr::Sqrt(e)) }
757/// Symbolic absolute value.
758pub fn abs(e: E) -> E { E::new(Expr::Abs(e)) }
759/// Symbolic Heaviside step function: 0 if x < 0, 1 if x >= 0.
760pub fn heaviside(e: E) -> E { E::new(Expr::Heaviside(e)) }
761/// Symbolic clamp: clamp value to [lo, hi]. Derivative passes through.
762/// Accepts `impl Into<E>` on all three args so bare numeric bounds
763/// compose naturally: `clamp(x, -1.0, 1.0)`.
764pub fn clamp(val: impl Into<E>, lo: impl Into<E>, hi: impl Into<E>) -> E {
765    E::new(Expr::Clamp(val.into(), lo.into(), hi.into()))
766}
767/// Symbolic power function. Auto-simplifies (e.g. x^0 = 1, x^1 = x).
768/// Accepts `impl Into<E>` for both args so bare numeric literals
769/// compose naturally: `pow(x, 2.0)`, `pow(x, 3)`.
770pub fn pow(base: impl Into<E>, exponent: impl Into<E>) -> E {
771    E::new(Expr::Pow(base.into(), exponent.into())).simplify()
772}
773
774// ---------------------------------------------------------------------------
775// Name-based function lookup
776//
777// Users that parse an expression tree (for example arael-macros turning a
778// constraint body or a fit expression into an arael_sym::E) need to map
779// function-name tokens like "sin", "atan2", "clamp" to the actual arael-sym
780// function. Keeping the authoritative list here (next to the functions
781// themselves) means external dispatchers don't have to duplicate it and
782// new functions land everywhere for free.
783// ---------------------------------------------------------------------------
784
785/// A scalar function exported by arael-sym, discovered by name. Tagged by
786/// arity so callers can validate the argument count without a second table.
787#[derive(Clone, Copy)]
788pub enum FunctionRef {
789    Unary(fn(E) -> E),
790    Binary(fn(E, E) -> E),
791    Ternary(fn(E, E, E) -> E),
792}
793
794/// The authoritative table of scalar functions arael-sym exposes by name.
795/// Adding a new `pub fn foo` above should add an entry here as well; every
796/// string-based dispatcher (the parser, the macro's constraint/fit
797/// dispatchers, user-facing autocompleters) reads from this one table.
798pub const FUNCTIONS: &[(&str, FunctionRef)] = &[
799    // Unary trig
800    ("sin", FunctionRef::Unary(sin)),
801    ("cos", FunctionRef::Unary(cos)),
802    ("tan", FunctionRef::Unary(tan)),
803    ("asin", FunctionRef::Unary(asin)),
804    ("acos", FunctionRef::Unary(acos)),
805    ("atan", FunctionRef::Unary(atan)),
806    ("sinh", FunctionRef::Unary(sinh)),
807    ("cosh", FunctionRef::Unary(cosh)),
808    ("tanh", FunctionRef::Unary(tanh)),
809    // Unary exp / log / pow-ish
810    ("exp", FunctionRef::Unary(exp)),
811    ("ln", FunctionRef::Unary(ln)),
812    ("log2", FunctionRef::Unary(log2)),
813    ("log10", FunctionRef::Unary(log10)),
814    ("sqrt", FunctionRef::Unary(sqrt)),
815    ("abs", FunctionRef::Unary(abs)),
816    ("heaviside", FunctionRef::Unary(heaviside)),
817    // Unary "safe" variants
818    ("identity", FunctionRef::Unary(identity)),
819    ("safe_sqrt", FunctionRef::Unary(safe_sqrt)),
820    ("safe_asin", FunctionRef::Unary(safe_asin)),
821    ("safe_acos", FunctionRef::Unary(safe_acos)),
822    // Binary
823    ("atan2", FunctionRef::Binary(atan2)),
824    ("pow", FunctionRef::Binary(pow)),
825    ("safe_atan2", FunctionRef::Binary(safe_atan2)),
826    ("rad_diff", FunctionRef::Binary(rad_diff)),
827    ("rad_sum", FunctionRef::Binary(rad_sum)),
828    // Ternary
829    ("clamp", FunctionRef::Ternary(clamp)),
830];
831
832/// Look up a scalar function by its conventional name. Returns `None` for
833/// unrecognized names -- callers typically emit a user-facing error in that
834/// case.
835pub fn function_by_name(name: &str) -> Option<FunctionRef> {
836    FUNCTIONS.iter().find(|(n, _)| *n == name).map(|(_, f)| *f)
837}
838
839/// Iterate over the names of every scalar function arael-sym exposes.
840/// Useful for autocomplete and "what functions are available?" queries.
841pub fn function_names() -> impl Iterator<Item = &'static str> {
842    FUNCTIONS.iter().map(|(n, _)| *n)
843}
844
845// ---------------------------------------------------------------------------
846// FunctionBag -- extensible registry of user-defined functions
847// ---------------------------------------------------------------------------
848
849/// An extensible registry of user-defined symbolic functions, used by
850/// [`parse::parse_with_functions`] to make runtime-constructed
851/// functions recognisable by the string parser.
852///
853/// Built-in functions (`sin`, `cos`, `clamp`, etc.) are *not* stored in
854/// the bag -- the parser falls back to [`function_by_name`] for any
855/// name the bag doesn't carry, so built-ins are always available
856/// regardless of what's in the bag. An empty bag means "built-ins
857/// only", which is what [`parse::parse`] uses.
858///
859/// Names registered in the bag shadow built-ins with the same name.
860///
861/// ## Registering a function
862///
863/// Pick the entry point that fits how you have the function in hand:
864///
865/// - [`add1`](Self::add1) / [`add2`](Self::add2) -- register a closure
866///   of arity 1 / 2, typically one produced by [`simple_func1`] /
867///   [`simple_func2`] / [`extern_func1`] / [`extern_func2`]. The bag
868///   invokes it once with placeholder symbols to extract name, params,
869///   and kind.
870/// - [`addN`](Self::addN) -- register an n-ary closure over `Vec<E>`.
871///   Pairs with [`simple_func`] / [`simple_func_derivs`] /
872///   [`extern_func`]. No upper arity bound.
873/// - [`add`](Self::add) -- register an already-formed `Expr::Func`
874///   value (for example, the output of
875///   [`simple_func1`]`("sq", |t| t*t)(symbol("x"))`).
876/// - [`add_symbolic`](Self::add_symbolic) -- when the body is an
877///   already-built `E` (e.g. from [`parse::parse`]) and you don't want
878///   to wrap it in a closure. Body is auto-differentiated.
879/// - [`add_with_kind`](Self::add_with_kind) -- escape hatch: name,
880///   parameter list, and a hand-built [`FuncKind`] directly.
881///
882/// ## Variable / parameter shadowing
883///
884/// Parameters declared when the function is registered always shadow
885/// variables of the same name in the caller's eval context. For
886/// example, after:
887///
888/// ```ignore
889/// let mut bag = FunctionBag::new();
890/// bag.add_symbolic("sq", vec!["x".into()], parse("x*x").unwrap());
891/// let e = parse_with_functions("sq(3)", &bag).unwrap();
892/// let vars = [("x", 5.0)].into_iter().collect();
893/// let r = e.eval(&vars).unwrap(); // 9.0, not 25.0
894/// ```
895///
896/// the outer `x = 5.0` is shadowed inside the function body by the
897/// formal parameter `x = 3.0` for the duration of the call.
898///
899/// ## See also
900///
901/// [`examples/calc_demo.rs`](https://github.com/harakas/arael/blob/master/examples/calc_demo.rs)
902/// is a bc-style REPL calculator built on top of `FunctionBag` +
903/// [`parse::parse_with_functions`]: variables, runtime function
904/// definitions (`name(args) = expr`), `vars` / `funcs` listings, and
905/// readline-style history.
906#[derive(Clone)]
907pub struct FunctionBag {
908    // Name -> (params, kind). Args are filled in at call time to build
909    // a fresh Expr::Func per invocation. Mirrors Expr::Func directly.
910    table: std::collections::HashMap<String, BagFunction>,
911}
912
913#[derive(Clone)]
914struct BagFunction {
915    params: std::vec::Vec<String>,
916    kind: FuncKind,
917}
918
919impl Default for FunctionBag {
920    fn default() -> Self { Self::new() }
921}
922
923fn extract_func_template(e: E, source: &str) -> Result<(String, std::vec::Vec<String>, FuncKind), String> {
924    match (*e.0).clone() {
925        Expr::Func { name, params, kind, .. } => Ok((name, params, kind)),
926        _ => Err(format!("{source}: expected Expr::Func, got a different expression")),
927    }
928}
929
930impl FunctionBag {
931    /// Empty bag. Built-in functions remain available via the parser's
932    /// fallback lookup; only user-added functions go here.
933    pub fn new() -> Self {
934        Self { table: std::collections::HashMap::new() }
935    }
936
937    /// Register a pre-built `Expr::Func` value. Use when you already
938    /// have an `E` (for example by calling one of the
939    /// [`simple_func1`] / [`simple_func2`] / [`simple_func`] /
940    /// [`extern_func1`] / [`extern_func2`] / [`extern_func`]
941    /// constructors on placeholder args).
942    ///
943    /// For registering closures directly, use [`add1`](Self::add1) /
944    /// [`add2`](Self::add2) / [`addN`](Self::addN).
945    ///
946    /// Returns `Err` if `e` is not an `Expr::Func`.
947    pub fn add(&mut self, e: E) -> Result<(), String> {
948        let (name, params, kind) = extract_func_template(e, "FunctionBag::add")?;
949        self.table.insert(name, BagFunction { params, kind });
950        Ok(())
951    }
952
953    /// Register a unary closure. The bag invokes it once with a
954    /// placeholder symbol to extract `(name, params, kind)`.
955    ///
956    /// ```ignore
957    /// bag.add1(simple_func1("sq", |t| t.clone() * t)).unwrap();
958    /// ```
959    pub fn add1<F>(&mut self, f: F) -> Result<(), String>
960    where F: FnOnce(E) -> E
961    {
962        let e = f(symbol("__a0"));
963        let (name, params, kind) = extract_func_template(e, "FunctionBag::add1")?;
964        self.table.insert(name, BagFunction { params, kind });
965        Ok(())
966    }
967
968    /// Register a binary closure.
969    ///
970    /// ```ignore
971    /// bag.add2(simple_func2("hypot",
972    ///     |a, b| sqrt(a.clone()*a + b.clone()*b))).unwrap();
973    /// ```
974    pub fn add2<F>(&mut self, f: F) -> Result<(), String>
975    where F: FnOnce(E, E) -> E
976    {
977        let e = f(symbol("__a0"), symbol("__a1"));
978        let (name, params, kind) = extract_func_template(e, "FunctionBag::add2")?;
979        self.table.insert(name, BagFunction { params, kind });
980        Ok(())
981    }
982
983    /// Register an n-ary closure. Pairs with [`simple_func`] /
984    /// [`simple_func_derivs`] / [`extern_func`] for arities >= 3 and
985    /// for functions whose arity is known only at runtime. The
986    /// closure takes `Vec<E>` to match the shape those constructors
987    /// return (`impl Fn(Vec<E>) -> E`).
988    ///
989    /// ```ignore
990    /// bag.addN(4, simple_func("blend", 4, |args: Vec<E>|
991    ///     args[0].clone() + args[1].clone() + args[2].clone() + args[3].clone()
992    /// )).unwrap();
993    /// ```
994    #[allow(non_snake_case)]
995    pub fn addN<F>(&mut self, arity: usize, f: F) -> Result<(), String>
996    where F: FnOnce(std::vec::Vec<E>) -> E
997    {
998        let placeholders: std::vec::Vec<E> =
999            (0..arity).map(|i| symbol(&format!("__a{i}"))).collect();
1000        let e = f(placeholders);
1001        let (name, params, kind) = extract_func_template(e, "FunctionBag::addN")?;
1002        self.table.insert(name, BagFunction { params, kind });
1003        Ok(())
1004    }
1005
1006    /// Convenience: register a symbolic function from an explicit
1007    /// `name`, parameter list, and body `E` whose free symbols match
1008    /// the params. Use this when you have the body as an already-built
1009    /// expression (e.g. from [`parse`]) rather than as a closure.
1010    pub fn add_symbolic(&mut self, name: impl Into<String>, params: std::vec::Vec<String>, body: E) {
1011        self.table.insert(
1012            name.into(),
1013            BagFunction { params, kind: FuncKind::Symbolic { body } },
1014        );
1015    }
1016
1017    /// Direct form: register a function from name + parameters + kind.
1018    /// Most callers should prefer [`add`](Self::add) (closures / E)
1019    /// or [`add_symbolic`](Self::add_symbolic) (parsed body) -- this
1020    /// is the escape hatch for building an unusual `FuncKind` by hand.
1021    pub fn add_with_kind(
1022        &mut self,
1023        name: impl Into<String>,
1024        params: std::vec::Vec<String>,
1025        kind: FuncKind,
1026    ) {
1027        self.table.insert(name.into(), BagFunction { params, kind });
1028    }
1029
1030    /// Remove a function by name. Returns `true` if it was present.
1031    /// Does not affect built-ins.
1032    pub fn remove(&mut self, name: &str) -> bool {
1033        self.table.remove(name).is_some()
1034    }
1035
1036    /// Is this name registered in the bag? Does *not* consider
1037    /// built-ins.
1038    pub fn contains(&self, name: &str) -> bool {
1039        self.table.contains_key(name)
1040    }
1041
1042    /// Collect all names registered in the bag. Order is unspecified.
1043    pub fn names(&self) -> std::vec::Vec<String> {
1044        self.table.keys().cloned().collect()
1045    }
1046
1047    /// Iterate over `(name, arity)` pairs for every function in the
1048    /// bag. Same data as [`names`](Self::names) with arity attached.
1049    pub fn entries(&self) -> impl Iterator<Item = (&str, usize)> {
1050        self.table.iter().map(|(k, v)| (k.as_str(), v.params.len()))
1051    }
1052
1053    /// Look up a function's parameter names and kind. Returns `None`
1054    /// if `name` isn't in the bag. Useful for pretty-printing or
1055    /// re-creating an `Expr::Func` outside the parser.
1056    pub fn get_info(&self, name: &str) -> Option<(&[String], &FuncKind)> {
1057        let f = self.table.get(name)?;
1058        Some((&f.params, &f.kind))
1059    }
1060
1061    /// Build an `Expr::Func` by looking up `name` in this bag and
1062    /// pairing it with `args`. Returns `None` if `name` is not
1063    /// registered; returns `Some(Err(..))` if the arity disagrees.
1064    /// `None` means the name isn't in the bag -- callers that want
1065    /// built-ins as a fallback should route through
1066    /// [`parse::parse_with_functions`] or
1067    /// [`function_by_name`](crate::function_by_name).
1068    pub fn call(&self, name: &str, args: &[E]) -> Option<Result<E, String>> {
1069        let f = self.table.get(name)?;
1070        if args.len() != f.params.len() {
1071            return Some(Err(format!(
1072                "{} expects {} argument(s), got {}",
1073                name, f.params.len(), args.len()
1074            )));
1075        }
1076        let func = E::new(Expr::Func {
1077            name: name.to_string(),
1078            params: f.params.clone(),
1079            kind: f.kind.clone(),
1080            args: args.to_vec(),
1081        });
1082        Some(Ok(func))
1083    }
1084}
1085
1086// --- Operator overloads for E (auto-simplify like SymPy) ---
1087
1088impl std::ops::Add for E {
1089    type Output = E;
1090    fn add(self, rhs: E) -> E {
1091        E::new(Expr::Add(self, rhs)).simplify()
1092    }
1093}
1094
1095impl std::ops::Sub for E {
1096    type Output = E;
1097    fn sub(self, rhs: E) -> E {
1098        E::new(Expr::Sub(self, rhs)).simplify()
1099    }
1100}
1101
1102impl std::ops::Mul for E {
1103    type Output = E;
1104    fn mul(self, rhs: E) -> E {
1105        E::new(Expr::Mul(self, rhs)).simplify()
1106    }
1107}
1108
1109impl std::ops::Div for E {
1110    type Output = E;
1111    fn div(self, rhs: E) -> E {
1112        E::new(Expr::Div(self, rhs)).simplify()
1113    }
1114}
1115
1116impl std::ops::Neg for E {
1117    type Output = E;
1118    fn neg(self) -> E {
1119        E::new(Expr::Neg(self)).simplify()
1120    }
1121}
1122
1123// --- Mixed ops: E with f64 (auto-simplify) ---
1124
1125impl std::ops::Add<f64> for E {
1126    type Output = E;
1127    fn add(self, rhs: f64) -> E { E::new(Expr::Add(self, constant(rhs))).simplify() }
1128}
1129
1130impl std::ops::Add<E> for f64 {
1131    type Output = E;
1132    fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self), rhs)).simplify() }
1133}
1134
1135impl std::ops::Sub<f64> for E {
1136    type Output = E;
1137    fn sub(self, rhs: f64) -> E { E::new(Expr::Sub(self, constant(rhs))).simplify() }
1138}
1139
1140impl std::ops::Sub<E> for f64 {
1141    type Output = E;
1142    fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self), rhs)).simplify() }
1143}
1144
1145impl std::ops::Mul<f64> for E {
1146    type Output = E;
1147    fn mul(self, rhs: f64) -> E { E::new(Expr::Mul(self, constant(rhs))).simplify() }
1148}
1149
1150impl std::ops::Mul<E> for f64 {
1151    type Output = E;
1152    fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self), rhs)).simplify() }
1153}
1154
1155impl std::ops::Div<f64> for E {
1156    type Output = E;
1157    fn div(self, rhs: f64) -> E { E::new(Expr::Div(self, constant(rhs))).simplify() }
1158}
1159
1160impl std::ops::Div<E> for f64 {
1161    type Output = E;
1162    fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self), rhs)).simplify() }
1163}
1164
// --- Mixed ops: E with i64 (auto-simplify).
//
// The pure-E and E-with-f64 impls above already cover the common
// cases. These i64 impls let bare integer literals (`2 * x`,
// `x + 1`) work without an explicit `.0` suffix: i64 is the only
// integer type with an operator impl involving `E`, so type inference
// resolves an unsuffixed literal to i64 (instead of the usual i32
// fallback). Literals suffixed `i64` (`2i64 * x`) also flow through
// here; other integer types (e.g. `2i32 * x`) do not compile. We
// convert to f64 at construction time to keep the expression tree
// representation uniform (magnitudes above 2^53 lose precision).
1174
1175impl std::ops::Add<i64> for E {
1176    type Output = E;
1177    fn add(self, rhs: i64) -> E { E::new(Expr::Add(self, constant(rhs as f64))).simplify() }
1178}
1179
1180impl std::ops::Add<E> for i64 {
1181    type Output = E;
1182    fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self as f64), rhs)).simplify() }
1183}
1184
1185impl std::ops::Sub<i64> for E {
1186    type Output = E;
1187    fn sub(self, rhs: i64) -> E { E::new(Expr::Sub(self, constant(rhs as f64))).simplify() }
1188}
1189
1190impl std::ops::Sub<E> for i64 {
1191    type Output = E;
1192    fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self as f64), rhs)).simplify() }
1193}
1194
1195impl std::ops::Mul<i64> for E {
1196    type Output = E;
1197    fn mul(self, rhs: i64) -> E { E::new(Expr::Mul(self, constant(rhs as f64))).simplify() }
1198}
1199
1200impl std::ops::Mul<E> for i64 {
1201    type Output = E;
1202    fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self as f64), rhs)).simplify() }
1203}
1204
1205impl std::ops::Div<i64> for E {
1206    type Output = E;
1207    fn div(self, rhs: i64) -> E { E::new(Expr::Div(self, constant(rhs as f64))).simplify() }
1208}
1209
1210impl std::ops::Div<E> for i64 {
1211    type Output = E;
1212    fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self as f64), rhs)).simplify() }
1213}
1214
1215// --- Custom function support ---
1216
1217/// Expand a Func node by substituting params -> args in the body.
1218pub(crate) fn expand_func(params: &[String], body: &E, args: &[E]) -> E {
1219    let mut expanded = body.clone();
1220    for (p, a) in params.iter().zip(args.iter()) {
1221        expanded = expanded.subs(p, a);
1222    }
1223    expanded
1224}
1225
1226/// Create a unary custom function. Returns a closure usable as `f(expr)`.
1227/// Codegen inlines the expanded body.
1228///
1229/// # Example
1230/// ```
1231/// use arael_sym::*;
1232/// sym! {
1233///     let square = simple_func1("square", |t| t * t);
1234///     let x = symbol("x");
1235///     assert_eq!(format!("{}", square(x + 1.0)), "square(x + 1)");
1236///     assert_eq!(format!("{}", square(x).diff(x)), "2 * x");
1237/// };
1238/// ```
1239pub fn simple_func1(name: &str, body: impl Fn(E) -> E) -> impl Fn(E) -> E + Clone {
1240    let name = name.to_string();
1241    let body = body(symbol("__p0"));
1242    move |arg: E| {
1243        E::new(Expr::Func {
1244            name: name.clone(),
1245            params: vec!["__p0".to_string()],
1246            kind: FuncKind::Symbolic { body: body.clone() },
1247            args: vec![arg],
1248        })
1249    }
1250}
1251
1252/// Create a binary custom function. Returns a closure usable as `f(a, b)`.
1253/// Codegen inlines the expanded body.
1254pub fn simple_func2(name: &str, body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> E + Clone {
1255    let name = name.to_string();
1256    let body = body(symbol("__p0"), symbol("__p1"));
1257    move |a: E, b: E| {
1258        E::new(Expr::Func {
1259            name: name.clone(),
1260            params: vec!["__p0".to_string(), "__p1".to_string()],
1261            kind: FuncKind::Symbolic { body: body.clone() },
1262            args: vec![a, b],
1263        })
1264    }
1265}
1266
1267/// Create an n-ary custom function. Returns a closure usable as `f(vec![...])`.
1268/// Codegen inlines the expanded body.
1269pub fn simple_func(name: &str, arity: usize, body: impl Fn(Vec<E>) -> E) -> impl Fn(Vec<E>) -> E + Clone {
1270    let name = name.to_string();
1271    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1272    let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1273    let body = body(syms);
1274    move |args: Vec<E>| {
1275        assert_eq!(args.len(), arity,
1276            "custom function '{}' expects {} args, got {}", name, arity, args.len());
1277        E::new(Expr::Func {
1278            name: name.clone(),
1279            params: params.clone(),
1280            kind: FuncKind::Symbolic { body: body.clone() },
1281            args,
1282        })
1283    }
1284}
1285
1286/// Create a unary function with explicit derivatives. Body used for eval
1287/// and codegen (inlined).
1288pub fn simple_func1_derivs(
1289    name: &str, body: impl Fn(E) -> E, derivs: impl Fn(E) -> [E; 1],
1290) -> impl Fn(E) -> E + Clone {
1291    let name = name.to_string();
1292    let p0 = symbol("__p0");
1293    let body = body(p0.clone());
1294    let d = derivs(p0);
1295    move |a: E| {
1296        E::new(Expr::Func {
1297            name: name.clone(),
1298            params: vec!["__p0".to_string()],
1299            kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: vec![d[0].clone()] },
1300            args: vec![a],
1301        })
1302    }
1303}
1304
1305/// Create a binary function with explicit derivatives. Body used for eval
1306/// and codegen (inlined).
1307///
1308/// # Example
1309/// ```
1310/// use arael_sym::*;
1311/// sym! {
1312///     // Or use the built-in safe_atan2():
1313///     let a = symbol("a");
1314///     let f = safe_atan2(sin(a), cos(a));
1315///     assert_eq!(format!("{}", f), "safe_atan2(sin(a), cos(a))");
1316/// };
1317/// ```
1318pub fn simple_func2_derivs(
1319    name: &str, body: impl Fn(E, E) -> E, derivs: impl Fn(E, E) -> [E; 2],
1320) -> impl Fn(E, E) -> E + Clone {
1321    let name = name.to_string();
1322    let p0 = symbol("__p0");
1323    let p1 = symbol("__p1");
1324    let body = body(p0.clone(), p1.clone());
1325    let d = derivs(p0, p1);
1326    move |a: E, b: E| {
1327        E::new(Expr::Func {
1328            name: name.clone(),
1329            params: vec!["__p0".to_string(), "__p1".to_string()],
1330            kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: vec![d[0].clone(), d[1].clone()] },
1331            args: vec![a, b],
1332        })
1333    }
1334}
1335
1336/// Create an n-ary function with explicit derivatives. Body used for eval
1337/// and codegen (inlined).
1338pub fn simple_func_derivs(
1339    name: &str, arity: usize, body: impl Fn(Vec<E>) -> E, derivs: impl Fn(Vec<E>) -> Vec<E>,
1340) -> impl Fn(Vec<E>) -> E + Clone {
1341    let name = name.to_string();
1342    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1343    let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1344    let body = body(syms.clone());
1345    let d = derivs(syms);
1346    assert_eq!(d.len(), arity, "derivs must return {} elements", arity);
1347    move |args: Vec<E>| {
1348        assert_eq!(args.len(), arity,
1349            "function '{}' expects {} args, got {}", name, arity, args.len());
1350        E::new(Expr::Func {
1351            name: name.clone(),
1352            params: params.clone(),
1353            kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: d.clone() },
1354            args,
1355        })
1356    }
1357}
1358
1359/// Create a unary extern function: codegen emits `call_path(arg)`,
1360/// explicit derivatives for differentiation, `eval_fn` for eval.
1361pub fn extern_func1(
1362    name: &str, call_path: &str,
1363    derivs: impl Fn(E) -> [E; 1],
1364    eval_fn: fn(&[f64]) -> f64,
1365) -> impl Fn(E) -> E + Clone {
1366    let name = name.to_string();
1367    let call_path = call_path.to_string();
1368    let d = derivs(symbol("__p0"));
1369    move |a: E| {
1370        E::new(Expr::Func {
1371            name: name.clone(),
1372            params: vec!["__p0".to_string()],
1373            kind: FuncKind::Extern {
1374                derivs: vec![d[0].clone()],
1375                eval_fn,
1376                call_path: call_path.clone(),
1377            },
1378            args: vec![a],
1379        })
1380    }
1381}
1382
1383/// Create a binary extern function: codegen emits `call_path(a, b)`,
1384/// explicit derivatives for differentiation, `eval_fn` for eval.
1385///
1386/// Use [`grad2`] to auto-compute derivatives from a body expression.
1387///
1388/// # Example
1389/// ```
1390/// use arael_sym::*;
1391/// sym! {
1392///     let f = extern_func2("rad_diff", "arael::utils::rad_diff",
1393///         grad2(|a, b| a - b),
1394///         |args: &[f64]| args[0] - args[1]);
1395///     let (x, y) = symbols!(x, y);
1396///     assert_eq!(format!("{}", f(x, y).diff(x)), "1");
1397///     assert_eq!(f(x, y).to_rust("f64"), "arael::utils::rad_diff(x, y)");
1398/// };
1399/// ```
1400pub fn extern_func2(
1401    name: &str, call_path: &str,
1402    derivs: impl Fn(E, E) -> [E; 2],
1403    eval_fn: fn(&[f64]) -> f64,
1404) -> impl Fn(E, E) -> E + Clone {
1405    let name = name.to_string();
1406    let call_path = call_path.to_string();
1407    let d = derivs(symbol("__p0"), symbol("__p1"));
1408    move |a: E, b: E| {
1409        E::new(Expr::Func {
1410            name: name.clone(),
1411            params: vec!["__p0".to_string(), "__p1".to_string()],
1412            kind: FuncKind::Extern {
1413                derivs: vec![d[0].clone(), d[1].clone()],
1414                eval_fn,
1415                call_path: call_path.clone(),
1416            },
1417            args: vec![a, b],
1418        })
1419    }
1420}
1421
1422/// Create an n-ary extern function: codegen emits `call_path(args...)`,
1423/// explicit derivatives for differentiation, `eval_fn` for eval.
1424pub fn extern_func(
1425    name: &str, arity: usize, call_path: &str,
1426    derivs: impl Fn(Vec<E>) -> Vec<E>,
1427    eval_fn: fn(&[f64]) -> f64,
1428) -> impl Fn(Vec<E>) -> E + Clone {
1429    let name = name.to_string();
1430    let call_path = call_path.to_string();
1431    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1432    let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1433    let d = derivs(syms);
1434    assert_eq!(d.len(), arity, "derivs must return {} elements", arity);
1435    move |args: Vec<E>| {
1436        assert_eq!(args.len(), arity,
1437            "extern function '{}' expects {} args, got {}", name, arity, args.len());
1438        E::new(Expr::Func {
1439            name: name.clone(),
1440            params: params.clone(),
1441            kind: FuncKind::Extern {
1442                derivs: d.clone(),
1443                eval_fn,
1444                call_path: call_path.clone(),
1445            },
1446            args,
1447        })
1448    }
1449}
1450
1451/// Compute the gradient of a unary function symbolically.
1452/// Returns a closure suitable for `simple_func1_derivs` or `extern_func1`.
1453pub fn grad1(body: impl Fn(E) -> E) -> impl Fn(E) -> [E; 1] + Clone {
1454    let p = symbol("__g0");
1455    let d = body(p).diff("__g0");
1456    move |a: E| { [d.subs("__g0", &a)] }
1457}
1458
1459/// Compute the gradient of a binary function symbolically.
1460/// Returns a closure suitable for `simple_func2_derivs` or `extern_func2`.
1461pub fn grad2(body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> [E; 2] + Clone {
1462    let p0 = symbol("__g0");
1463    let p1 = symbol("__g1");
1464    let expr = body(p0, p1);
1465    let d0 = expr.diff("__g0");
1466    let d1 = expr.diff("__g1");
1467    move |a: E, b: E| {
1468        [d0.subs("__g0", &a).subs("__g1", &b),
1469         d1.subs("__g0", &a).subs("__g1", &b)]
1470    }
1471}
1472
/// Wrap an angle in radians into `[-pi, pi]`.
///
/// Values already inside the closed interval are returned unchanged.
/// Anything else is shifted by a whole number of turns into `[-pi, pi)`.
/// NaN and infinities produce NaN.
fn rad2rad(v: f64) -> f64 {
    use std::f64::consts::PI;
    if (-PI..=PI).contains(&v) {
        return v;
    }
    let full_turn = 2.0 * PI;
    let turns = (v / full_turn + 0.5).floor();
    v - full_turn * turns
}
1482
1483/// Rollover-safe radian difference: $(a - b)$ normalized to $[-\pi, \pi]$.
1484///
1485/// Differentiation treats it as $a - b$: $\frac{\partial}{\partial a} = 1$, $\frac{\partial}{\partial b} = -1$.
1486pub fn rad_diff(a: E, b: E) -> E {
1487    extern_func2("rad_diff", "arael::utils::rad_diff",
1488        grad2(|a, b| a - b),
1489        |args: &[f64]| rad2rad(args[0] - args[1]))(a, b)
1490}
1491
1492/// Rollover-safe radian sum: $(a + b)$ normalized to $[-\pi, \pi]$.
1493///
1494/// Differentiation treats it as $a + b$: $\frac{\partial}{\partial a} = 1$, $\frac{\partial}{\partial b} = 1$.
1495pub fn rad_sum(a: E, b: E) -> E {
1496    extern_func2("rad_sum", "arael::utils::rad_sum",
1497        grad2(|a, b| a + b),
1498        |args: &[f64]| rad2rad(args[0] + args[1]))(a, b)
1499}
1500
1501/// Identity function: $\text{identity}(x) = x$, $\frac{d}{dx} = 1$.
1502///
1503/// The simplifier does not look inside Func nodes, so `identity(a - b)`
1504/// prevents term reordering across the boundary. Codegen wraps the inlined
1505/// body in parentheses to preserve evaluation order.
1506///
1507/// Use this to guard expressions against floating-point cancellation.
1508/// For example, $\text{identity}(1 - x^2) + \epsilon^2$ ensures
1509/// the subtraction evaluates first, then $\epsilon^2$ is added to the result.
1510pub fn identity(x: E) -> E {
1511    simple_func1("identity", |t| t)(x)
1512}
1513
1514/// Safe atan2 with non-diverging derivatives.
1515///
1516/// $$\text{atan2\\_safe}(y, x) = \text{atan2}(y, x)$$
1517///
1518/// $$\frac{\partial}{\partial y} = \frac{x}{x^2 + y^2 + \epsilon^2}, \quad
1519///   \frac{\partial}{\partial x} = \frac{-y}{x^2 + y^2 + \epsilon^2}$$
1520///
1521/// The $\epsilon^2$ term prevents division by zero at $(0, 0)$.
1522pub fn safe_atan2(y: E, x: E) -> E {
1523    simple_func2_derivs("safe_atan2",
1524        atan2,
1525        |y, x| {
1526            let eps2 = epsilon() * epsilon();
1527            let d = x.clone()*x.clone() + y.clone()*y.clone() + eps2;
1528            [x / d.clone(), -y / d]
1529        })(y, x)
1530}
1531
1532/// Safe asin with clamped domain and non-diverging derivative.
1533///
1534/// $$\text{asin\\_safe}(x) = \arcsin(\text{clamp}(x, -1, 1))$$
1535///
1536/// $$\frac{d}{dx} = \frac{1}{\sqrt{\text{identity}(1 - x^2) + \epsilon^2}}$$
1537///
1538/// The [`identity`] guard prevents the simplifier from reordering
1539/// $1 - x^2$ and $\epsilon^2$, avoiding floating-point cancellation.
1540pub fn safe_asin(x: E) -> E {
1541    simple_func1_derivs("safe_asin",
1542        |x| asin(clamp(x, c(-1.0), c(1.0))),
1543        // Clamp x to [-1, 1] in the derivative too, so eval at |x| > 1
1544        // gives a finite value (1 / epsilon) instead of NaN. The body's
1545        // clamp keeps `asin` inside its domain; this one keeps the
1546        // derivative's `1 - x^2` non-negative.
1547        |x| {
1548            let xc = clamp(x, c(-1.0), c(1.0));
1549            [c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
1550        }
1551    )(x)
1552}
1553
1554/// Safe acos with clamped domain and non-diverging derivative.
1555///
1556/// $$\text{acos\\_safe}(x) = \arccos(\text{clamp}(x, -1, 1))$$
1557///
1558/// $$\frac{d}{dx} = \frac{-1}{\sqrt{\text{identity}(1 - x^2) + \epsilon^2}}$$
1559pub fn safe_acos(x: E) -> E {
1560    simple_func1_derivs("safe_acos",
1561        |x| acos(clamp(x, c(-1.0), c(1.0))),
1562        // Same fix as `safe_asin`: clamp x in the derivative so
1563        // `1 - x^2` stays non-negative for any input.
1564        |x| {
1565            let xc = clamp(x, c(-1.0), c(1.0));
1566            [-c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
1567        }
1568    )(x)
1569}
1570
1571/// Safe square root: clamps negative inputs to zero, non-diverging derivative.
1572///
1573/// $$\text{safe\_sqrt}(x) = \sqrt{\max(x, 0)}$$
1574///
1575/// $$\frac{d}{dx} = \frac{1}{2\sqrt{x + \epsilon^2}}$$
1576///
1577/// Negative inputs evaluate as zero. The runtime function asserts if the input
1578/// is more than noise-level negative. The $\epsilon^2$ term prevents the
1579/// derivative from diverging at $x = 0$.
1580pub fn safe_sqrt(x: E) -> E {
1581    extern_func1("safe_sqrt", "arael::utils::safe_sqrt",
1582        // Guard the derivative's `x` against negative inputs so
1583        // `sqrt(x + eps^2)` stays defined. `heaviside(x)` folds
1584        // negative x to zero; at eval time that gives `0.5 / eps`,
1585        // finite and large, instead of NaN.
1586        |x| [c(0.5) / sqrt(identity(x.clone() * heaviside(x)) + epsilon()*epsilon())],
1587        |args| {
1588            let v = args[0];
1589            if v <= 0.0 { 0.0 } else { v.sqrt() }
1590        }
1591    )(x)
1592}
1593
1594// Re-export linalg types
1595pub use linalg::{SymVec, SymMat, jacobian};
1596pub use parse::{parse, parse_with_functions, ParseError};
1597pub use geo::{vect2sym, vect3sym, matrix2sym, matrix3sym, quaternsym};
1598pub use cse::cse;
1599pub use arael_sym_macros::sym;
1600
1601#[cfg(test)]
1602mod tests {
1603    use super::*;
1604    use std::collections::HashMap;
1605
1606    #[test]
1607    fn simple_func_identity_display() {
1608        sym! {
1609            let identity = simple_func1("identity", |t| t);
1610            let x = symbol("x");
1611            assert_eq!(format!("{}", identity(x)), "identity(x)");
1612        }
1613    }
1614
1615    #[test]
1616    fn simple_func_identity_diff() {
1617        sym! {
1618            let identity = simple_func1("identity", |t| t);
1619            let x = symbol("x");
1620            let f = identity(x);
1621            assert_eq!(format!("{}", f.diff("x")), "1");
1622        }
1623    }
1624
1625    #[test]
1626    fn simple_func_identity_chain_rule() {
1627        sym! {
1628            let identity = simple_func1("identity", |t| t);
1629            let x = symbol("x");
1630            let f = identity(x * x);
1631            assert_eq!(format!("{}", f.diff("x")), "2 * x");
1632        }
1633    }
1634
1635    #[test]
1636    fn simple_func_identity_eval() {
1637        sym! {
1638            let identity = simple_func1("identity", |t| t);
1639            let x = symbol("x");
1640            let f = identity(x);
1641            let vars = HashMap::from([("x", 5.0)]);
1642            assert_eq!(f.eval(&vars).unwrap(), 5.0);
1643        }
1644    }
1645
1646    #[test]
1647    fn simple_func_square() {
1648        sym! {
1649            let square = simple_func1("square", |t| t * t);
1650            let x = symbol("x");
1651            let f = square(x + 1.0);
1652            assert_eq!(format!("{}", f), "square(x + 1)");
1653            assert_eq!(format!("{}", f.diff("x")), "2 * (x + 1)");
1654        }
1655    }
1656
1657    #[test]
1658    fn simple_func_square_eval() {
1659        sym! {
1660            let square = simple_func1("square", |t| t * t);
1661            let x = symbol("x");
1662            let f = square(x);
1663            let vars = HashMap::from([("x", 4.0)]);
1664            assert_eq!(f.eval(&vars).unwrap(), 16.0);
1665        }
1666    }
1667
1668    #[test]
1669    fn simple_func_binary() {
1670        sym! {
1671            let f = simple_func2("prod", |a, b| a * b);
1672            let x = symbol("x");
1673            let y = symbol("y");
1674            let result = f(x, y);
1675            assert_eq!(format!("{}", result), "prod(x, y)");
1676            assert_eq!(format!("{}", result.diff("x")), "y");
1677            assert_eq!(format!("{}", result.diff("y")), "x");
1678        }
1679    }
1680
1681    #[test]
1682    fn simple_func_nested() {
1683        sym! {
1684            let identity = simple_func1("identity", |t| t);
1685            let square = simple_func1("square", |t| t * t);
1686            let x = symbol("x");
1687            let f = identity(square(x));
1688            assert_eq!(format!("{}", f), "identity(square(x))");
1689            assert_eq!(format!("{}", f.diff("x")), "2 * x");
1690        }
1691    }
1692
1693    #[test]
1694    fn simple_func_my_sin() {
1695        sym! {
1696            let my_sin = simple_func1("my_sin", |t| sin(t));
1697            let x = symbol("x");
1698            let f = my_sin(x);
1699            assert_eq!(format!("{}", f), "my_sin(x)");
1700            assert_eq!(format!("{}", f.diff("x")), "cos(x)");
1701        }
1702    }
1703
1704    #[test]
1705    fn simple_func_my_sin_chain_rule() {
1706        sym! {
1707            let my_sin = simple_func1("my_sin", |t| sin(t));
1708            let x = symbol("x");
1709            let f = my_sin(x * x);
1710            assert_eq!(format!("{}", f.diff("x")), "2 * x * cos(x^2)");
1711        }
1712    }
1713
1714    #[test]
1715    fn simple_func_to_rust() {
1716        sym! {
1717            let identity = simple_func1("identity", |t| t);
1718            let x = symbol("x");
1719            let f = identity(x);
1720            assert_eq!(f.to_rust("f64"), "x");
1721        }
1722    }
1723
1724    #[test]
1725    fn simple_func_latex() {
1726        sym! {
1727            let identity = simple_func1("identity", |t| t);
1728            let x = symbol("x");
1729            let f = identity(x);
1730            assert_eq!(f.to_latex(), "\\operatorname{identity}\\left(x\\right)");
1731        }
1732    }
1733
1734    #[test]
1735    fn simple_func_free_vars() {
1736        sym! {
1737            let identity = simple_func1("identity", |t| t);
1738            let x = symbol("x");
1739            let f = identity(x + symbol("y"));
1740            let vars = f.free_vars();
1741            assert!(vars.contains("x"));
1742            assert!(vars.contains("y"));
1743            assert!(!vars.contains("t"));
1744        }
1745    }
1746
1747    #[test]
1748    fn simple_func_subs() {
1749        sym! {
1750            let identity = simple_func1("identity", |t| t);
1751            let x = symbol("x");
1752            let f = identity(x);
1753            let g = f.subs("x", &constant(3.0));
1754            assert_eq!(format!("{}", g), "identity(3)");
1755        }
1756    }
1757
1758    #[test]
1759    fn simple_func_simplify_constants() {
1760        sym! {
1761            let square = simple_func1("square", |t| t * t);
1762            let f = square(constant(3.0));
1763            let s = f.simplify();
1764            assert_eq!(format!("{}", s), "9");
1765        }
1766    }
1767
1768    #[test]
1769    fn simple_func_nary() {
1770        sym! {
1771            let f = simple_func("triple_sum", 3, |v| v[0].clone() + v[1].clone() + v[2].clone());
1772            let x = symbol("x");
1773            let y = symbol("y");
1774            let z = symbol("z");
1775            let result = f(vec![x, y, z]);
1776            assert_eq!(format!("{}", result), "triple_sum(x, y, z)");
1777            assert_eq!(format!("{}", result.diff("x")), "1");
1778        }
1779    }
1780
1781    #[test]
1782    fn simple_func_expand() {
1783        sym! {
1784            let square = simple_func1("square", |t| t * t);
1785            let x = symbol("x");
1786            let f = square(x + 1.0);
1787            let expanded = f.expand();
1788            assert_eq!(format!("{}", expanded), "x^2 + 2 * x + 1");
1789        }
1790    }
1791
1792    // --- Simple func derivs tests ---
1793
1794    #[test]
1795    fn simple_func_derivs_codegen() {
1796        // codegen should inline the body, not the derivs
1797        sym! {
1798            let f = simple_func1_derivs("inv", |t| 1.0 / t, |t| [-1.0 / (t * t)]);
1799            let x = symbol("x");
1800            assert_eq!(f(x).to_rust("f64"), "1.0_f64 / x");
1801        }
1802    }
1803
1804    // --- Safe function tests ---
1805
1806    #[test]
1807    fn safe_atan2_diff() {
1808        sym! {
1809            let a = symbol("a");
1810            let b = symbol("b");
1811            let f = safe_atan2(a, b);
1812            let da = f.diff("a");
1813            let vars = HashMap::from([("a", 1.0), ("b", 1.0)]);
1814            let v = da.eval(&vars).unwrap();
1815            assert!((v - 0.5).abs() < 1e-10, "d/da at (1,1) = {}, expected 0.5", v);
1816        }
1817    }
1818
1819    #[test]
1820    fn safe_atan2_eval() {
1821        sym! {
1822            let a = symbol("a");
1823            let b = symbol("b");
1824            let f = safe_atan2(a, b);
1825            let vars = HashMap::from([("a", 1.0), ("b", 1.0)]);
1826            let v = f.eval(&vars).unwrap();
1827            assert!((v - std::f64::consts::FRAC_PI_4).abs() < 1e-10);
1828        }
1829    }
1830
1831    #[test]
1832    fn safe_atan2_chain_rule() {
1833        sym! {
1834            let t = symbol("t");
1835            let f = safe_atan2(sin(t), cos(t));
1836            let df = f.diff("t");
1837            let vars = HashMap::from([("t", 0.5)]);
1838            let v = df.eval(&vars).unwrap();
1839            assert!((v - 1.0).abs() < 1e-8, "df/dt at t=0.5 = {}, expected 1", v);
1840        }
1841    }
1842
1843    #[test]
1844    fn safe_atan2_at_zero() {
1845        sym! {
1846            let a = symbol("a");
1847            let b = symbol("b");
1848            let da = safe_atan2(a, b).diff("a");
1849            let vars = HashMap::from([("a", 0.0), ("b", 0.0)]);
1850            let v = da.eval(&vars).unwrap();
1851            assert!(v.is_finite(), "derivative at (0,0) should be finite, got {}", v);
1852        }
1853    }
1854
1855    #[test]
1856    fn safe_asin_eval() {
1857        sym! {
1858            let x = symbol("x");
1859            let f = safe_asin(x);
1860            // Normal value
1861            let vars = HashMap::from([("x", 0.5)]);
1862            assert!((f.eval(&vars).unwrap() - 0.5_f64.asin()).abs() < 1e-10);
1863            // Clamped: safe_asin(1.5) = asin(1.0) = pi/2
1864            let vars = HashMap::from([("x", 1.5)]);
1865            assert!((f.eval(&vars).unwrap() - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
1866        }
1867    }
1868
1869    #[test]
1870    fn safe_asin_deriv_finite() {
1871        sym! {
1872            let x = symbol("x");
1873            let da = safe_asin(x).diff("x");
1874            // At x=1.0, vanilla asin derivative diverges; safe version stays finite
1875            let vars = HashMap::from([("x", 1.0)]);
1876            let v = da.eval(&vars).unwrap();
1877            assert!(v.is_finite(), "safe_asin derivative at 1.0 should be finite, got {}", v);
1878        }
1879    }
1880
1881    #[test]
1882    fn safe_acos_eval() {
1883        sym! {
1884            let x = symbol("x");
1885            let f = safe_acos(x);
1886            let vars = HashMap::from([("x", 0.5)]);
1887            assert!((f.eval(&vars).unwrap() - 0.5_f64.acos()).abs() < 1e-10);
1888            // Clamped: safe_acos(-1.5) = acos(-1.0) = pi
1889            let vars = HashMap::from([("x", -1.5)]);
1890            assert!((f.eval(&vars).unwrap() - std::f64::consts::PI).abs() < 1e-10);
1891        }
1892    }
1893
1894    #[test]
1895    fn identity_codegen_parens() {
1896        sym! {
1897            let x = symbol("x");
1898            let f = identity(c(1.0) - x * x) + epsilon * epsilon;
1899            let code = f.to_rust("f64");
1900            // identity forces parens around its body
1901            assert!(code.contains("(-x.powf(2.0_f64) + 1.0_f64)"),
1902                "expected parens around identity body, got: {}", code);
1903        }
1904    }
1905
1906    #[test]
1907    fn identity_diff() {
1908        sym! {
1909            let x = symbol("x");
1910            let f = identity(x * x);
1911            assert_eq!(format!("{}", f.diff("x")), "2 * x");
1912        }
1913    }
1914
1915    #[test]
1916    fn safe_acos_deriv_finite() {
1917        sym! {
1918            let x = symbol("x");
1919            let da = safe_acos(x).diff("x");
1920            let vars = HashMap::from([("x", 1.0)]);
1921            let v = da.eval(&vars).unwrap();
1922            assert!(v.is_finite(), "safe_acos derivative at 1.0 should be finite, got {}", v);
1923        }
1924    }
1925
1926    /// Regression: the derivative of `safe_asin` / `safe_acos` /
1927    /// `safe_sqrt` must stay finite even for inputs well outside the
1928    /// safe domain. Previously each derivative formula used the raw
1929    /// `x` unclamped, so for `|x| > 1` (asin/acos) or `x < 0` (sqrt)
1930    /// the inner `sqrt(1 - x^2 + eps^2)` / `sqrt(x + eps^2)` evaluated
1931    /// a negative operand and produced NaN.
1932    #[test]
1933    fn safe_derivs_finite_outside_domain() {
1934        sym! {
1935            let x = symbol("x");
1936            let d_asin = safe_asin(x).diff("x");
1937            let d_acos = safe_acos(x).diff("x");
1938            let d_sqrt = safe_sqrt(x).diff("x");
1939            for v in [-5.0_f64, -1.5, 1.5, 5.0] {
1940                let vars = HashMap::from([("x", v)]);
1941                let a = d_asin.eval(&vars).unwrap();
1942                let c = d_acos.eval(&vars).unwrap();
1943                assert!(a.is_finite(), "safe_asin'({}) should be finite, got {}", v, a);
1944                assert!(c.is_finite(), "safe_acos'({}) should be finite, got {}", v, c);
1945            }
1946            for v in [-5.0_f64, -1.0, -1e-12, 0.0] {
1947                let vars = HashMap::from([("x", v)]);
1948                let s = d_sqrt.eval(&vars).unwrap();
1949                assert!(s.is_finite(), "safe_sqrt'({}) should be finite, got {}", v, s);
1950            }
1951        }
1952    }
1953
1954    #[test]
1955    fn safe_sqrt_eval() {
1956        sym! {
1957            let x = symbol("x");
1958            let f = safe_sqrt(x);
1959            let vars = HashMap::from([("x", 4.0)]);
1960            assert!((f.eval(&vars).unwrap() - 2.0).abs() < 1e-10);
1961            // Negative input: safe_sqrt(-1e-10) = 0 (clamped)
1962            let vars = HashMap::from([("x", -1e-10)]);
1963            assert!(f.eval(&vars).unwrap().abs() < 1e-10);
1964            // Zero: safe_sqrt(0) = 0
1965            let vars = HashMap::from([("x", 0.0)]);
1966            assert!(f.eval(&vars).unwrap().abs() < 1e-10);
1967        }
1968    }
1969
1970    #[test]
1971    fn safe_sqrt_deriv_at_zero() {
1972        sym! {
1973            let x = symbol("x");
1974            let df = safe_sqrt(x).diff("x");
1975            // At x=0, vanilla sqrt derivative diverges; safe version stays finite
1976            let vars = HashMap::from([("x", 0.0)]);
1977            let v = df.eval(&vars).unwrap();
1978            assert!(v.is_finite(), "safe_sqrt derivative at 0 should be finite, got {}", v);
1979        }
1980    }
1981
1982    // --- Grad helper tests ---
1983
1984    #[test]
1985    fn grad2_basic() {
1986        sym! {
1987            let g = grad2(|a, b| a * b);
1988            let x = symbol("x");
1989            let y = symbol("y");
1990            let [da, db] = g(x, y);
1991            assert_eq!(format!("{}", da), "y");
1992            assert_eq!(format!("{}", db), "x");
1993        }
1994    }
1995
1996    #[test]
1997    fn grad1_basic() {
1998        sym! {
1999            let g = grad1(|t| t * t);
2000            let x = symbol("x");
2001            let [dt] = g(x);
2002            assert_eq!(format!("{}", dt), "2 * x");
2003        }
2004    }
2005
2006    // --- Extern function tests ---
2007
2008    #[test]
2009    fn extern_func_display() {
2010        sym! {
2011            let x = symbol("x");
2012            let y = symbol("y");
2013            let f = rad_diff(x, y);
2014            assert_eq!(format!("{}", f), "rad_diff(x, y)");
2015        }
2016    }
2017
2018    #[test]
2019    fn extern_func_diff() {
2020        sym! {
2021            let x = symbol("x");
2022            let y = symbol("y");
2023            let f = rad_diff(x, y);
2024            assert_eq!(format!("{}", f.diff("x")), "1");
2025            assert_eq!(format!("{}", f.diff("y")), "-1");
2026        }
2027    }
2028
2029    #[test]
2030    fn extern_func_chain_rule() {
2031        sym! {
2032            let x = symbol("x");
2033            let y = symbol("y");
2034            let f = rad_diff(x * x, y);
2035            assert_eq!(format!("{}", f.diff("x")), "2 * x");
2036        }
2037    }
2038
2039    #[test]
2040    fn extern_func_eval() {
2041        // For small angles, rad_diff(a,b) = a - b (no wrapping needed)
2042        sym! {
2043            let x = symbol("x");
2044            let y = symbol("y");
2045            let f = rad_diff(x, y);
2046            let vars = HashMap::from([("x", 0.3), ("y", 0.1)]);
2047            let v = f.eval(&vars).unwrap();
2048            assert!((v - 0.2).abs() < 1e-10);
2049        }
2050    }
2051
2052    #[test]
2053    fn extern_func_eval_wrapping() {
2054        // rad_diff(0, 2*pi) should be 0 (wrapping)
2055        sym! {
2056            let x = symbol("x");
2057            let f = rad_diff(constant(0.0), x);
2058            let vars = HashMap::from([("x", 2.0 * std::f64::consts::PI)]);
2059            let v = f.eval(&vars).unwrap();
2060            assert!(v.abs() < 1e-10, "rad_diff(0, 2*pi) = {}, expected 0", v);
2061        }
2062    }
2063
2064    #[test]
2065    fn extern_func_to_rust() {
2066        sym! {
2067            let x = symbol("x");
2068            let y = symbol("y");
2069            let f = rad_diff(x, y);
2070            let code = f.to_rust("f64");
2071            assert_eq!(code, "arael::utils::rad_diff(x, y)");
2072        }
2073    }
2074
2075    #[test]
2076    fn extern_func_latex() {
2077        sym! {
2078            let x = symbol("x");
2079            let y = symbol("y");
2080            let f = rad_diff(x, y);
2081            assert_eq!(f.to_latex(), "\\operatorname{rad\\_diff}\\left(x, y\\right)");
2082        }
2083    }
2084
2085    #[test]
2086    fn extern_func_subs() {
2087        sym! {
2088            let x = symbol("x");
2089            let y = symbol("y");
2090            let f = rad_diff(x, y);
2091            let g = f.subs("x", &constant(1.0));
2092            assert_eq!(format!("{}", g), "rad_diff(1, y)");
2093        }
2094    }
2095
2096    #[test]
2097    fn extern_func_no_const_fold() {
2098        // Extern functions should not be constant-folded in simplify
2099        sym! {
2100            let f = rad_diff(constant(1.0), constant(2.0));
2101            let s = f.simplify();
2102            assert_eq!(format!("{}", s), "rad_diff(1, 2)");
2103        }
2104    }
2105
2106    #[test]
2107    fn extern_func_no_expand() {
2108        // Extern functions should stay opaque on expand
2109        sym! {
2110            let x = symbol("x");
2111            let y = symbol("y");
2112            let f = rad_diff(x + 1.0, y);
2113            let expanded = f.expand();
2114            assert_eq!(format!("{}", expanded), "rad_diff(x + 1, y)");
2115        }
2116    }
2117
2118    #[test]
2119    fn extern_func_free_vars() {
2120        sym! {
2121            let x = symbol("x");
2122            let y = symbol("y");
2123            let f = rad_diff(x, y);
2124            let vars = f.free_vars();
2125            assert!(vars.contains("x"));
2126            assert!(vars.contains("y"));
2127            assert!(!vars.contains("__a"));
2128            assert!(!vars.contains("__b"));
2129        }
2130    }
2131
2132    #[test]
2133    fn rad_sum_diff() {
2134        sym! {
2135            let x = symbol("x");
2136            let y = symbol("y");
2137            let f = rad_sum(x, y);
2138            assert_eq!(format!("{}", f.diff("x")), "1");
2139            assert_eq!(format!("{}", f.diff("y")), "1");
2140        }
2141    }
2142
2143    #[test]
2144    fn rad_sum_to_rust() {
2145        sym! {
2146            let x = symbol("x");
2147            let y = symbol("y");
2148            let f = rad_sum(x, y);
2149            assert_eq!(f.to_rust("f64"), "arael::utils::rad_sum(x, y)");
2150        }
2151    }
2152
2153    #[test]
2154    fn extern_func_def() {
2155        sym! {
2156            fn my_eval(args: &[f64]) -> f64 { args[0] - args[1] }
2157            let my_diff = extern_func2("my_diff", "my_mod::diff",
2158                grad2(|a, b| a - b), my_eval);
2159            let x = symbol("x");
2160            let y = symbol("y");
2161            let f = my_diff(x, y);
2162            assert_eq!(format!("{}", f), "my_diff(x, y)");
2163            assert_eq!(format!("{}", f.diff("x")), "1");
2164            assert_eq!(format!("{}", f.diff("y")), "-1");
2165            assert_eq!(f.to_rust("f64"), "my_mod::diff(x, y)");
2166        }
2167    }
2168
2169    // --- Heaviside tests ---
2170
2171    #[test]
2172    fn heaviside_eval() {
2173        let vars = HashMap::from([("x", 0.0)]);
2174        sym! {
2175            let x = symbol("x");
2176            let h = heaviside(x);
2177            assert_eq!(h.eval(&HashMap::from([("x", -1.0)])).unwrap(), 0.0);
2178            assert_eq!(h.eval(&vars).unwrap(), 1.0);
2179            assert_eq!(h.eval(&HashMap::from([("x", 3.0)])).unwrap(), 1.0);
2180        }
2181    }
2182
2183    #[test]
2184    fn heaviside_diff() {
2185        sym! {
2186            let x = symbol("x");
2187            assert_eq!(format!("{}", heaviside(x).diff("x")), "0");
2188            assert_eq!(format!("{}", heaviside(x * x - 1.0).diff("x")), "0");
2189        }
2190    }
2191
2192    #[test]
2193    fn heaviside_display() {
2194        sym! {
2195            let x = symbol("x");
2196            assert_eq!(format!("{}", heaviside(x)), "H(x)");
2197        }
2198    }
2199
2200    #[test]
2201    fn heaviside_composition_diff() {
2202        sym! {
2203            let x = symbol("x");
2204            // d/dx [H(1-x) * x^2] = 2x (H' = 0, product rule kills that term)
2205            let f = heaviside(1.0 - x) * x * x;
2206            assert_eq!(format!("{}", f.diff("x")), "2 * x * H(-x + 1)");
2207        }
2208    }
2209
2210    // --- Clamp tests ---
2211
2212    #[test]
2213    fn clamp_eval() {
2214        sym! {
2215            let x = symbol("x");
2216            let f = clamp(x, c(0.0), c(1.0));
2217            assert_eq!(f.eval(&HashMap::from([("x", 0.5)])).unwrap(), 0.5);
2218            assert_eq!(f.eval(&HashMap::from([("x", -2.0)])).unwrap(), 0.0);
2219            assert_eq!(f.eval(&HashMap::from([("x", 5.0)])).unwrap(), 1.0);
2220        }
2221    }
2222
2223    #[test]
2224    fn clamp_diff_passthrough() {
2225        sym! {
2226            let x = symbol("x");
2227            // d/dx clamp(x, 0, 1) = 1 (pass-through)
2228            assert_eq!(format!("{}", clamp(x, c(0.0), c(1.0)).diff("x")), "1");
2229            // d/dx clamp(x^2, 0, 1) = 2x (chain rule on first arg)
2230            assert_eq!(format!("{}", clamp(x * x, c(0.0), c(1.0)).diff("x")), "2 * x");
2231        }
2232    }
2233
2234    #[test]
2235    fn clamp_display() {
2236        sym! {
2237            let x = symbol("x");
2238            assert_eq!(format!("{}", clamp(x, c(0.0), c(1.0))), "clamp(x, 0, 1)");
2239        }
2240    }
2241
2242    #[test]
2243    fn clamp_simplify_constants() {
2244        sym! {
2245            let f = clamp(c(5.0), c(0.0), c(1.0));
2246            assert_eq!(format!("{}", f.simplify()), "1");
2247            let g = clamp(c(-3.0), c(0.0), c(1.0));
2248            assert_eq!(format!("{}", g.simplify()), "0");
2249            let h = clamp(c(0.5), c(0.0), c(1.0));
2250            assert_eq!(format!("{}", h.simplify()), "0.5");
2251        }
2252    }
2253
2254    // --- clamp-based safe_asin tests (simple_func1 version) ---
2255
2256    #[test]
2257    fn clamp_asin_eval() {
2258        sym! {
2259            let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
2260            let x = symbol("x");
2261
2262            // Normal value
2263            let f = my_asin(x);
2264            let val = f.eval(&HashMap::from([("x", 0.5)])).unwrap();
2265            assert!((val - 0.5_f64.asin()).abs() < 1e-10);
2266
2267            // Out of range: no NaN
2268            let val_hi = f.eval(&HashMap::from([("x", 1.5)])).unwrap();
2269            assert!((val_hi - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
2270
2271            let val_lo = f.eval(&HashMap::from([("x", -1.5)])).unwrap();
2272            assert!((val_lo + std::f64::consts::FRAC_PI_2).abs() < 1e-10);
2273        }
2274    }
2275
2276    #[test]
2277    fn clamp_asin_diff() {
2278        sym! {
2279            let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
2280            let x = symbol("x");
2281            let f = my_asin(x);
2282            // Derivative: 1/sqrt(1 - clamp(x,-1,1)^2) * 1 (clamp pass-through)
2283            let df = f.diff("x");
2284            // Numerically verify at x=0.5
2285            let vars = HashMap::from([("x", 0.5)]);
2286            let dval = df.eval(&vars).unwrap();
2287            let expected = 1.0 / (1.0 - 0.25_f64).sqrt(); // 1/sqrt(0.75)
2288            assert!((dval - expected).abs() < 1e-10);
2289        }
2290    }
2291
2292    #[test]
2293    fn heaviside_to_rust() {
2294        sym! {
2295            let x = symbol("x");
2296            assert_eq!(heaviside(x).to_rust("f64"), "x.heaviside()");
2297        }
2298    }
2299
2300    #[test]
2301    fn clamp_to_rust() {
2302        sym! {
2303            let x = symbol("x");
2304            assert_eq!(clamp(x, c(0.0), c(1.0)).to_rust("f64"), "x.clamp(0.0_f64, 1.0_f64)");
2305        }
2306    }
2307
2308    #[test]
2309    fn parse_heaviside() {
2310        let f = parse("H(x)").unwrap();
2311        assert_eq!(format!("{}", f), "H(x)");
2312        assert_eq!(format!("{}", f.diff("x")), "0");
2313    }
2314
2315    #[test]
2316    fn parse_clamp() {
2317        let f = parse("clamp(x, 0, 1)").unwrap();
2318        assert_eq!(format!("{}", f), "clamp(x, 0, 1)");
2319        assert_eq!(format!("{}", f.diff("x")), "1");
2320    }
2321
2322    // --- Named constant tests ---
2323
2324    #[test]
2325    fn named_const_pi_display() {
2326        assert_eq!(format!("{}", pi()), "pi");
2327    }
2328
2329    #[test]
2330    fn named_const_pi_eval() {
2331        let vars = HashMap::new();
2332        assert_eq!(pi().eval(&vars).unwrap(), std::f64::consts::PI);
2333    }
2334
2335    #[test]
2336    fn named_const_pi_diff() {
2337        assert_eq!(format!("{}", pi().diff("x")), "0");
2338    }
2339
2340    #[test]
2341    fn named_const_pi_codegen() {
2342        assert_eq!(pi().to_rust("f64"), "std::f64::consts::PI");
2343        assert_eq!(pi().to_rust("f32"), "std::f32::consts::PI");
2344    }
2345
2346    #[test]
2347    fn named_const_pi_latex() {
2348        assert_eq!(pi().to_latex(), "\\pi");
2349    }
2350
2351    #[test]
2352    fn named_const_epsilon_display() {
2353        assert_eq!(format!("{}", epsilon()), "epsilon");
2354    }
2355
2356    #[test]
2357    fn named_const_epsilon_eval() {
2358        let vars = HashMap::new();
2359        assert_eq!(epsilon().eval(&vars).unwrap(), f64::EPSILON);
2360    }
2361
2362    #[test]
2363    fn named_const_epsilon_codegen() {
2364        assert_eq!(epsilon().to_rust("f64"), "f64::EPSILON");
2365        assert_eq!(epsilon().to_rust("f32"), "f32::EPSILON");
2366    }
2367
2368    #[test]
2369    fn named_const_euler_display() {
2370        assert_eq!(format!("{}", euler()), "e");
2371    }
2372
2373    #[test]
2374    fn named_const_euler_eval() {
2375        let vars = HashMap::new();
2376        assert_eq!(euler().eval(&vars).unwrap(), std::f64::consts::E);
2377    }
2378
2379    #[test]
2380    fn named_const_euler_codegen() {
2381        assert_eq!(euler().to_rust("f64"), "std::f64::consts::E");
2382    }
2383
2384    #[test]
2385    fn named_const_epsilon_survives_simplification() {
2386        sym! {
2387            let x = symbol("x");
2388            let f = (x + epsilon()).simplify();
2389            assert_eq!(format!("{}", f), "x + epsilon");
2390        }
2391    }
2392
2393    #[test]
2394    fn named_const_not_free_var() {
2395        sym! {
2396            let x = symbol("x");
2397            let f = x + pi();
2398            let vars = f.free_vars();
2399            assert!(vars.contains("x"));
2400            assert!(!vars.contains("pi"));
2401        }
2402    }
2403
2404    #[test]
2405    fn named_const_custom() {
2406        let tau = named_const("tau", std::f64::consts::TAU,
2407            "std::f32::consts::TAU", "std::f64::consts::TAU", "\\tau");
2408        assert_eq!(format!("{}", tau), "tau");
2409        let vars = HashMap::new();
2410        assert_eq!(tau.eval(&vars).unwrap(), std::f64::consts::TAU);
2411        assert_eq!(tau.to_rust("f64"), "std::f64::consts::TAU");
2412        assert_eq!(tau.to_latex(), "\\tau");
2413    }
2414
2415    // --- Algebraic simplification of named constants ---
2416
2417    #[test]
2418    fn named_const_pi_add_pi() {
2419        sym! {
2420            let f = (pi() + pi()).simplify();
2421            assert_eq!(format!("{}", f), "2 * pi");
2422        }
2423    }
2424
2425    #[test]
2426    fn named_const_pi_sub_pi() {
2427        sym! {
2428            let f = (pi() - pi()).simplify();
2429            assert_eq!(format!("{}", f), "0");
2430        }
2431    }
2432
2433    #[test]
2434    fn named_const_pi_mul_pi() {
2435        sym! {
2436            let f = (pi() * pi()).simplify();
2437            assert_eq!(format!("{}", f), "pi^2");
2438        }
2439    }
2440
2441    #[test]
2442    fn named_const_epsilon_add() {
2443        sym! {
2444            let x = symbol("x");
2445            let f = (x + epsilon() + epsilon()).simplify();
2446            assert_eq!(format!("{}", f), "x + 2 * epsilon");
2447        }
2448    }
2449
2450    // --- Trig-pi simplification ---
2451
2452    #[test]
2453    fn trig_sin_pi() {
2454        sym! { assert_eq!(format!("{}", sin(pi()).simplify()), "0"); }
2455    }
2456
2457    #[test]
2458    fn trig_cos_pi() {
2459        sym! { assert_eq!(format!("{}", cos(pi()).simplify()), "-1"); }
2460    }
2461
2462    #[test]
2463    fn trig_sin_pi_half() {
2464        sym! { assert_eq!(format!("{}", sin(pi() / 2.0).simplify()), "1"); }
2465    }
2466
2467    #[test]
2468    fn trig_cos_pi_half() {
2469        sym! { assert_eq!(format!("{}", cos(pi() / 2.0).simplify()), "0"); }
2470    }
2471
2472    #[test]
2473    fn trig_sin_pi_quarter() {
2474        sym! {
2475            let f = sin(pi() / 4.0).simplify();
2476            let vars = HashMap::new();
2477            let v = f.eval(&vars).unwrap();
2478            assert!((v - std::f64::consts::FRAC_1_SQRT_2).abs() < 1e-10);
2479        }
2480    }
2481
2482    #[test]
2483    fn trig_cos_pi_third() {
2484        sym! {
2485            let f = cos(pi() / 3.0).simplify();
2486            assert_eq!(format!("{}", f), "0.5");
2487        }
2488    }
2489
2490    #[test]
2491    fn trig_sin_2pi() {
2492        sym! { assert_eq!(format!("{}", sin(2.0 * pi()).simplify()), "0"); }
2493    }
2494
2495    #[test]
2496    fn trig_cos_2pi() {
2497        sym! { assert_eq!(format!("{}", cos(2.0 * pi()).simplify()), "1"); }
2498    }
2499
2500    #[test]
2501    fn trig_tan_pi() {
2502        sym! { assert_eq!(format!("{}", tan(pi()).simplify()), "0"); }
2503    }
2504
2505    #[test]
2506    fn trig_sin_pi_sixth() {
2507        sym! { assert_eq!(format!("{}", sin(pi() / 6.0).simplify()), "0.5"); }
2508    }
2509
2510    // --- Log/exp-e simplification ---
2511
2512    #[test]
2513    fn ln_e() {
2514        sym! { assert_eq!(format!("{}", ln(euler()).simplify()), "1"); }
2515    }
2516
2517    // --- sym! macro bare identifier tests ---
2518
2519    #[test]
2520    fn sym_macro_bare_pi() {
2521        sym! {
2522            let x = symbol("x");
2523            let f = 2.0 * pi * x;
2524            assert_eq!(format!("{}", f), "2 * x * pi");
2525        }
2526    }
2527
2528    #[test]
2529    fn sym_macro_bare_epsilon() {
2530        sym! {
2531            let x = symbol("x");
2532            let f = x * x + epsilon;
2533            assert_eq!(format!("{}", f), "x^2 + epsilon");
2534        }
2535    }
2536
2537    #[test]
2538    fn sym_macro_pi_call_still_works() {
2539        // pi() with parens should also work (not double-rewritten)
2540        sym! {
2541            let f = pi();
2542            assert_eq!(format!("{}", f), "pi");
2543        }
2544    }
2545
2546    #[test]
2547    fn ln_e_pow_x() {
2548        sym! {
2549            let x = symbol("x");
2550            let f = ln(pow(euler(), x)).simplify();
2551            assert_eq!(format!("{}", f), "x");
2552        }
2553    }
2554}
2555