#![allow(clippy::should_implement_trait)]
mod diff;
mod eval;
mod fmt;
mod simplify;
mod linalg;
mod parse;
pub mod geo;
pub mod cse;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
/// Shared handle to an immutable symbolic expression node.
///
/// Cloning an `E` only bumps the `Rc` reference count, so subtrees are
/// shared between expressions rather than copied. Equality compares the
/// pointed-to [`Expr`] trees structurally (through `Rc`'s `PartialEq`).
#[derive(Clone, PartialEq)]
pub struct E(Rc<Expr>);
// Mirrors `impl Eq for Expr`. Trees containing `Const(f64::NAN)` are not
// equal to themselves, so reflexivity only holds for NaN-free expressions.
impl Eq for E {}
impl E {
fn new(expr: Expr) -> E {
E(Rc::new(expr))
}
pub fn symbols(&self) -> std::collections::HashSet<String> {
let mut out = std::collections::HashSet::new();
self.collect_symbols(&mut out);
out
}
fn collect_symbols(&self, out: &mut std::collections::HashSet<String>) {
match &*self.0 {
Expr::Sym(s) => { out.insert(s.clone()); }
Expr::Const(_) | Expr::NamedConst { .. } => {}
Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
| Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
| Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
| Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
| Expr::Sqrt(a) | Expr::Abs(a)
| Expr::Heaviside(a) => { a.collect_symbols(out); }
Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
| Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
a.collect_symbols(out);
b.collect_symbols(out);
}
Expr::Clamp(a, b, c) => {
a.collect_symbols(out);
b.collect_symbols(out);
c.collect_symbols(out);
}
Expr::Func { args, .. } => {
for arg in args { arg.collect_symbols(out); }
}
}
}
pub fn substitute(&self, subs: &[(E, E)]) -> E {
for (from, to) in subs {
if self == from { return to.clone(); }
}
match &*self.0 {
Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => self.clone(),
Expr::Neg(a) => -a.substitute(subs),
Expr::Add(a, b) => a.substitute(subs) + b.substitute(subs),
Expr::Sub(a, b) => a.substitute(subs) - b.substitute(subs),
Expr::Mul(a, b) => a.substitute(subs) * b.substitute(subs),
Expr::Div(a, b) => a.substitute(subs) / b.substitute(subs),
Expr::Pow(a, b) => pow(a.substitute(subs), b.substitute(subs)),
Expr::Sin(a) => sin(a.substitute(subs)),
Expr::Cos(a) => cos(a.substitute(subs)),
Expr::Tan(a) => tan(a.substitute(subs)),
Expr::Asin(a) => asin(a.substitute(subs)),
Expr::Acos(a) => acos(a.substitute(subs)),
Expr::Atan(a) => atan(a.substitute(subs)),
Expr::Atan2(a, b) => atan2(a.substitute(subs), b.substitute(subs)),
Expr::Sinh(a) => sinh(a.substitute(subs)),
Expr::Cosh(a) => cosh(a.substitute(subs)),
Expr::Tanh(a) => tanh(a.substitute(subs)),
Expr::Exp(a) => exp(a.substitute(subs)),
Expr::Ln(a) => ln(a.substitute(subs)),
Expr::Log2(a) => log2(a.substitute(subs)),
Expr::Log10(a) => ln(a.substitute(subs)) / ln(constant(10.0)),
Expr::Sqrt(a) => sqrt(a.substitute(subs)),
Expr::Abs(a) => abs(a.substitute(subs)),
Expr::Heaviside(a) => heaviside(a.substitute(subs)),
Expr::Clamp(a, lo, hi) => clamp(a.substitute(subs), lo.substitute(subs), hi.substitute(subs)),
Expr::Func { name, params, kind, args } => {
let new_args = args.iter().map(|a| a.substitute(subs)).collect();
E::new(Expr::Func { name: name.clone(), params: params.clone(), kind: kind.clone(), args: new_args })
}
}
}
}
/// Lets an `E` be used wherever a `&Expr` is expected, e.g. `match &*e { ... }`
/// or calling `Expr` methods directly on the handle.
impl std::ops::Deref for E {
    type Target = Expr;

    fn deref(&self) -> &Expr {
        &*self.0
    }
}

/// Borrows the expression node behind the shared handle.
impl AsRef<Expr> for E {
    fn as_ref(&self) -> &Expr {
        let node: &Expr = &self.0;
        node
    }
}
/// A single node of a symbolic expression tree.
///
/// Children are held through [`E`] handles, so subtrees can be shared.
/// Nodes are normally built with the module-level constructors (`symbol`,
/// `constant`, `sin`, `pow`, ...) or the arithmetic operators on `E`.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Free variable, identified by name.
Sym(String),
/// Numeric literal.
Const(f64),
/// Unary minus.
Neg(E),
Add(E, E),
Sub(E, E),
Mul(E, E),
Div(E, E),
/// `base ^ exponent`.
Pow(E, E),
Sin(E),
Cos(E),
Tan(E),
Asin(E),
Acos(E),
Atan(E),
/// Two-argument arctangent; the first operand is `y`, the second `x`
/// (see the `atan2(y, x)` constructor).
Atan2(E, E),
Sinh(E),
Cosh(E),
Tanh(E),
Exp(E),
/// Natural logarithm.
Ln(E),
Log2(E),
Log10(E),
Sqrt(E),
Abs(E),
/// Step function of its argument.
Heaviside(E),
/// `clamp(value, lo, hi)`, in that operand order.
Clamp(E, E, E),
/// Named constant (e.g. `pi`): numeric `value` plus the spellings used for
/// Rust `f32` / `f64` output and for LaTeX output.
NamedConst {
name: String,
value: f64,
rust_f32: String,
rust_f64: String,
latex: String,
},
/// Call of a user-defined function. `params` are the placeholder names the
/// function's body/derivatives are written in, `args` are the actual call
/// arguments (one per param), and `kind` says how the function is defined.
Func {
name: String,
params: Vec<String>,
kind: FuncKind,
args: Vec<E>,
},
}
/// How a user-defined function ([`Expr::Func`]) is defined.
#[derive(Debug, Clone, PartialEq)]
// The derived `PartialEq` compares `Extern::eval_fn` by function-pointer
// address, which this lint flags as unreliable: one function may have
// several addresses, and distinct functions may be merged. Allowed here;
// presumably best-effort equality on that field is acceptable (TODO confirm).
#[allow(unpredictable_function_pointer_comparisons)]
pub enum FuncKind {
/// Defined only by a symbolic `body`; derivatives come from differentiating
/// that body (this is the only kind `auto_diff_body` returns a body for).
Symbolic { body: E },
/// Symbolic `body` plus user-supplied partial derivatives, one per parameter.
SymbolicDerivs { body: E, derivs: Vec<E> },
/// External function with no symbolic body. `eval_fn` computes numeric
/// values, `call_path` is the path emitted in generated Rust code, and
/// `derivs` holds one partial derivative per parameter.
Extern { derivs: Vec<E>, eval_fn: fn(&[f64]) -> f64, call_path: String },
}
impl FuncKind {
    /// Body to differentiate automatically. Only plain [`FuncKind::Symbolic`]
    /// functions have one; the other kinds carry explicit derivatives.
    pub fn auto_diff_body(&self) -> Option<&E> {
        if let FuncKind::Symbolic { body } = self {
            Some(body)
        } else {
            None
        }
    }

    /// User-supplied partial derivatives (one per parameter), if any.
    pub fn derivs(&self) -> Option<&[E]> {
        match self {
            FuncKind::Symbolic { .. } => None,
            FuncKind::SymbolicDerivs { derivs, .. } => Some(derivs.as_slice()),
            FuncKind::Extern { derivs, .. } => Some(derivs.as_slice()),
        }
    }

    /// Symbolic body, present for every kind except [`FuncKind::Extern`].
    pub fn body(&self) -> Option<&E> {
        match self {
            FuncKind::Extern { .. } => None,
            FuncKind::Symbolic { body } => Some(body),
            FuncKind::SymbolicDerivs { body, .. } => Some(body),
        }
    }

    /// Numeric implementation of an [`FuncKind::Extern`] function.
    pub fn eval_fn(&self) -> Option<fn(&[f64]) -> f64> {
        if let FuncKind::Extern { eval_fn, .. } = self {
            Some(*eval_fn)
        } else {
            None
        }
    }
}
/// Hash consistent with the derived `PartialEq`: every field that equality
/// compares is fed to the hasher, after the variant tag.
impl Hash for FuncKind {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // The tag keeps different kinds with equal payloads apart.
        let tag = std::mem::discriminant(self);
        tag.hash(state);
        match self {
            FuncKind::Symbolic { body } => {
                body.hash(state);
            }
            FuncKind::SymbolicDerivs { body, derivs } => {
                body.hash(state);
                derivs.hash(state);
            }
            FuncKind::Extern { derivs, eval_fn, call_path } => {
                derivs.hash(state);
                // Function pointers are hashed by address, matching the
                // address comparison done by the derived `PartialEq`.
                let address = *eval_fn as usize;
                address.hash(state);
                call_path.hash(state);
            }
        }
    }
}
// NOTE: `Expr` carries `f64` payloads, so this `Eq` is only a true
// equivalence relation for NaN-free trees (`Const(NAN) != Const(NAN)`).
impl Eq for Expr {}
/// Structural hash, kept consistent with the derived `PartialEq`.
///
/// Floating-point payloads are hashed by bit pattern after folding `-0.0`
/// into `0.0`. The two compare equal under `==`, so they must also hash
/// equal; otherwise `HashMap`/`HashSet` lookups of equal expressions can
/// miss. For `NamedConst`, only `name` and `value` are hashed. That is still
/// consistent, because equal values need only agree on the hashed fields.
impl Hash for Expr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Bit pattern with the sign of zero normalised (`0.0 == -0.0`).
        fn float_bits(v: f64) -> u64 {
            if v == 0.0 {
                0.0_f64.to_bits()
            } else {
                v.to_bits()
            }
        }

        std::mem::discriminant(self).hash(state);
        match self {
            Expr::Sym(s) => s.hash(state),
            Expr::Const(v) => float_bits(*v).hash(state),
            Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
            | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
            | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
            | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
            | Expr::Sqrt(a) | Expr::Abs(a)
            | Expr::Heaviside(a) => a.hash(state),
            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
            | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
                a.hash(state);
                b.hash(state);
            }
            Expr::Clamp(a, b, c) => {
                a.hash(state);
                b.hash(state);
                c.hash(state);
            }
            Expr::NamedConst { name, value, .. } => {
                name.hash(state);
                float_bits(*value).hash(state);
            }
            Expr::Func { name, params, kind, args } => {
                name.hash(state);
                params.hash(state);
                kind.hash(state);
                args.hash(state);
            }
        }
    }
}
impl Hash for E {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
/// Creates a free symbol (variable) called `name`.
pub fn symbol(name: &str) -> E {
    let owned = String::from(name);
    E::new(Expr::Sym(owned))
}
pub trait AsVarName {
fn var_name(&self) -> &str;
fn var_expr(&self) -> E {
symbol(self.var_name())
}
}
impl AsVarName for &str {
fn var_name(&self) -> &str { self }
}
impl AsVarName for &&str {
fn var_name(&self) -> &str { self }
}
impl AsVarName for str {
fn var_name(&self) -> &str { self }
}
impl AsVarName for String {
fn var_name(&self) -> &str { self.as_str() }
}
impl AsVarName for &String {
fn var_name(&self) -> &str { self.as_str() }
}
impl AsVarName for &E {
fn var_name(&self) -> &str { (*self).var_name() }
fn var_expr(&self) -> E { (*self).clone() }
}
impl AsVarName for E {
fn var_name(&self) -> &str {
match self.as_ref() {
Expr::Sym(name) => name.as_str(),
_ => panic!("AsVarName::var_name: expected a symbol, got `{self}`"),
}
}
fn var_expr(&self) -> E { self.clone() }
}
/// Declares several symbols at once and returns them as a tuple.
///
/// `symbols!(x, y)` expands to `(symbol("x"), symbol("y"),)`. The trailing
/// comma means a single name still yields a 1-tuple: `let (x,) = symbols!(x);`.
#[macro_export]
macro_rules! symbols {
($($name:ident),+ $(,)?) => {
( $( $crate::symbol(stringify!($name)) ),+ , )
};
}
/// Creates a numeric constant node.
pub fn constant(val: f64) -> E {
    let node = Expr::Const(val);
    E::new(node)
}

/// `f64` values convert directly into constant nodes.
impl From<f64> for E {
    fn from(v: f64) -> E {
        constant(v)
    }
}

/// Converted through `f64`; magnitudes above 2^53 are rounded to the nearest
/// representable `f64`.
impl From<i64> for E {
    fn from(v: i64) -> E {
        constant(v as f64)
    }
}

/// Lossless: every `i32` is exactly representable as `f64`.
impl From<i32> for E {
    fn from(v: i32) -> E {
        constant(f64::from(v))
    }
}
/// Creates a named constant with its numeric `value` and the spellings used
/// for Rust `f32` / `f64` output and for LaTeX output.
pub fn named_const(name: &str, value: f64, rust_f32: &str, rust_f64: &str, latex: &str) -> E {
    let node = Expr::NamedConst {
        name: name.to_owned(),
        value,
        rust_f32: rust_f32.to_owned(),
        rust_f64: rust_f64.to_owned(),
        latex: latex.to_owned(),
    };
    E::new(node)
}

/// The circle constant π.
pub fn pi() -> E {
    let value = std::f64::consts::PI;
    named_const("pi", value, "std::f32::consts::PI", "std::f64::consts::PI", "\\pi")
}

/// Machine epsilon of `f64` (emitted as the matching `f32`/`f64` constant in Rust code).
pub fn epsilon() -> E {
    let value = f64::EPSILON;
    named_const("epsilon", value, "f32::EPSILON", "f64::EPSILON", "\\epsilon")
}

/// Euler's number e.
pub fn euler() -> E {
    let value = std::f64::consts::E;
    named_const("e", value, "std::f32::consts::E", "std::f64::consts::E", "e")
}
pub fn c(val: f64) -> E { constant(val) }
pub fn sin(e: E) -> E { E::new(Expr::Sin(e)) }
pub fn cos(e: E) -> E { E::new(Expr::Cos(e)) }
pub fn tan(e: E) -> E { E::new(Expr::Tan(e)) }
pub fn asin(e: E) -> E { E::new(Expr::Asin(e)) }
pub fn acos(e: E) -> E { E::new(Expr::Acos(e)) }
pub fn atan(e: E) -> E { E::new(Expr::Atan(e)) }
pub fn atan2(y: E, x: E) -> E { E::new(Expr::Atan2(y, x)) }
pub fn sinh(e: E) -> E { E::new(Expr::Sinh(e)) }
pub fn cosh(e: E) -> E { E::new(Expr::Cosh(e)) }
pub fn tanh(e: E) -> E { E::new(Expr::Tanh(e)) }
pub fn exp(e: E) -> E { E::new(Expr::Exp(e)) }
pub fn ln(e: E) -> E { E::new(Expr::Ln(e)) }
pub fn log2(e: E) -> E { E::new(Expr::Log2(e)) }
pub fn log10(e: E) -> E { E::new(Expr::Log10(e)) }
pub fn sqrt(e: E) -> E { E::new(Expr::Sqrt(e)) }
pub fn abs(e: E) -> E { E::new(Expr::Abs(e)) }
pub fn heaviside(e: E) -> E { E::new(Expr::Heaviside(e)) }
pub fn clamp(val: impl Into<E>, lo: impl Into<E>, hi: impl Into<E>) -> E {
E::new(Expr::Clamp(val.into(), lo.into(), hi.into()))
}
pub fn pow(base: impl Into<E>, exponent: impl Into<E>) -> E {
E::new(Expr::Pow(base.into(), exponent.into())).simplify()
}
/// A built-in constructor tagged by arity, so it can be stored in the
/// [`FUNCTIONS`] table and invoked with the matching number of arguments.
#[derive(Clone, Copy)]
pub enum FunctionRef {
/// One argument, e.g. `sin`.
Unary(fn(E) -> E),
/// Two arguments, e.g. `atan2` or `pow`.
Binary(fn(E, E) -> E),
/// Three arguments, e.g. `clamp`.
Ternary(fn(E, E, E) -> E),
}
/// Name → constructor table of the built-in functions.
///
/// Looked up linearly by [`function_by_name`], which returns the first entry
/// whose name matches.
pub const FUNCTIONS: &[(&str, FunctionRef)] = &[
// Elementary unary functions.
("sin", FunctionRef::Unary(sin)),
("cos", FunctionRef::Unary(cos)),
("tan", FunctionRef::Unary(tan)),
("asin", FunctionRef::Unary(asin)),
("acos", FunctionRef::Unary(acos)),
("atan", FunctionRef::Unary(atan)),
("sinh", FunctionRef::Unary(sinh)),
("cosh", FunctionRef::Unary(cosh)),
("tanh", FunctionRef::Unary(tanh)),
("exp", FunctionRef::Unary(exp)),
("ln", FunctionRef::Unary(ln)),
("log2", FunctionRef::Unary(log2)),
("log10", FunctionRef::Unary(log10)),
("sqrt", FunctionRef::Unary(sqrt)),
("abs", FunctionRef::Unary(abs)),
("heaviside", FunctionRef::Unary(heaviside)),
// Unary helpers built as user-defined (`Expr::Func`) functions.
("identity", FunctionRef::Unary(identity)),
("safe_sqrt", FunctionRef::Unary(safe_sqrt)),
("safe_asin", FunctionRef::Unary(safe_asin)),
("safe_acos", FunctionRef::Unary(safe_acos)),
// Binary functions.
("atan2", FunctionRef::Binary(atan2)),
("pow", FunctionRef::Binary(pow)),
("safe_atan2", FunctionRef::Binary(safe_atan2)),
("rad_diff", FunctionRef::Binary(rad_diff)),
("rad_sum", FunctionRef::Binary(rad_sum)),
// Ternary functions.
("clamp", FunctionRef::Ternary(clamp)),
];
/// Looks up a built-in function by name in [`FUNCTIONS`]. The first match wins.
pub fn function_by_name(name: &str) -> Option<FunctionRef> {
    for &(candidate, func) in FUNCTIONS {
        if candidate == name {
            return Some(func);
        }
    }
    None
}

/// Names of all built-in functions, in table order.
pub fn function_names() -> impl Iterator<Item = &'static str> {
    FUNCTIONS.iter().map(|&(label, _)| label)
}
/// Registry of user-defined function templates, keyed by function name.
///
/// Each entry stores a function's parameter placeholder names and its
/// [`FuncKind`]. [`FunctionBag::call`] instantiates an entry with concrete
/// arguments.
#[derive(Clone)]
pub struct FunctionBag {
// Function name -> template; inserting an existing name replaces the entry.
table: std::collections::HashMap<String, BagFunction>,
}
/// A stored function template: parameter names plus kind. The arguments of
/// the expression the template was extracted from are not kept.
#[derive(Clone)]
struct BagFunction {
params: std::vec::Vec<String>,
kind: FuncKind,
}
// An empty bag, same as `FunctionBag::new()`.
impl Default for FunctionBag {
fn default() -> Self { Self::new() }
}
/// Splits an `Expr::Func` call into its template parts `(name, params, kind)`,
/// discarding the call arguments.
///
/// `source` prefixes the error message returned when `e` is not a function
/// call node.
fn extract_func_template(e: E, source: &str) -> Result<(String, std::vec::Vec<String>, FuncKind), String> {
    if let Expr::Func { name, params, kind, .. } = &*e {
        return Ok((name.clone(), params.clone(), kind.clone()));
    }
    Err(format!("{source}: expected Expr::Func, got a different expression"))
}
impl FunctionBag {
pub fn new() -> Self {
Self { table: std::collections::HashMap::new() }
}
pub fn add(&mut self, e: E) -> Result<(), String> {
let (name, params, kind) = extract_func_template(e, "FunctionBag::add")?;
self.table.insert(name, BagFunction { params, kind });
Ok(())
}
pub fn add1<F>(&mut self, f: F) -> Result<(), String>
where F: FnOnce(E) -> E
{
let e = f(symbol("__a0"));
let (name, params, kind) = extract_func_template(e, "FunctionBag::add1")?;
self.table.insert(name, BagFunction { params, kind });
Ok(())
}
pub fn add2<F>(&mut self, f: F) -> Result<(), String>
where F: FnOnce(E, E) -> E
{
let e = f(symbol("__a0"), symbol("__a1"));
let (name, params, kind) = extract_func_template(e, "FunctionBag::add2")?;
self.table.insert(name, BagFunction { params, kind });
Ok(())
}
#[allow(non_snake_case)]
pub fn addN<F>(&mut self, arity: usize, f: F) -> Result<(), String>
where F: FnOnce(std::vec::Vec<E>) -> E
{
let placeholders: std::vec::Vec<E> =
(0..arity).map(|i| symbol(&format!("__a{i}"))).collect();
let e = f(placeholders);
let (name, params, kind) = extract_func_template(e, "FunctionBag::addN")?;
self.table.insert(name, BagFunction { params, kind });
Ok(())
}
pub fn add_symbolic(&mut self, name: impl Into<String>, params: std::vec::Vec<String>, body: E) {
self.table.insert(
name.into(),
BagFunction { params, kind: FuncKind::Symbolic { body } },
);
}
pub fn add_with_kind(
&mut self,
name: impl Into<String>,
params: std::vec::Vec<String>,
kind: FuncKind,
) {
self.table.insert(name.into(), BagFunction { params, kind });
}
pub fn remove(&mut self, name: &str) -> bool {
self.table.remove(name).is_some()
}
pub fn contains(&self, name: &str) -> bool {
self.table.contains_key(name)
}
pub fn names(&self) -> std::vec::Vec<String> {
self.table.keys().cloned().collect()
}
pub fn entries(&self) -> impl Iterator<Item = (&str, usize)> {
self.table.iter().map(|(k, v)| (k.as_str(), v.params.len()))
}
pub fn get_info(&self, name: &str) -> Option<(&[String], &FuncKind)> {
let f = self.table.get(name)?;
Some((&f.params, &f.kind))
}
pub fn call(&self, name: &str, args: &[E]) -> Option<Result<E, String>> {
let f = self.table.get(name)?;
if args.len() != f.params.len() {
return Some(Err(format!(
"{} expects {} argument(s), got {}",
name, f.params.len(), args.len()
)));
}
let func = E::new(Expr::Func {
name: name.to_string(),
params: f.params.clone(),
kind: f.kind.clone(),
args: args.to_vec(),
});
Some(Ok(func))
}
}
impl std::ops::Add for E {
type Output = E;
fn add(self, rhs: E) -> E {
E::new(Expr::Add(self, rhs)).simplify()
}
}
impl std::ops::Sub for E {
type Output = E;
fn sub(self, rhs: E) -> E {
E::new(Expr::Sub(self, rhs)).simplify()
}
}
impl std::ops::Mul for E {
type Output = E;
fn mul(self, rhs: E) -> E {
E::new(Expr::Mul(self, rhs)).simplify()
}
}
impl std::ops::Div for E {
type Output = E;
fn div(self, rhs: E) -> E {
E::new(Expr::Div(self, rhs)).simplify()
}
}
impl std::ops::Neg for E {
type Output = E;
fn neg(self) -> E {
E::new(Expr::Neg(self)).simplify()
}
}
impl std::ops::Add<f64> for E {
type Output = E;
fn add(self, rhs: f64) -> E { E::new(Expr::Add(self, constant(rhs))).simplify() }
}
impl std::ops::Add<E> for f64 {
type Output = E;
fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self), rhs)).simplify() }
}
impl std::ops::Sub<f64> for E {
type Output = E;
fn sub(self, rhs: f64) -> E { E::new(Expr::Sub(self, constant(rhs))).simplify() }
}
impl std::ops::Sub<E> for f64 {
type Output = E;
fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self), rhs)).simplify() }
}
impl std::ops::Mul<f64> for E {
type Output = E;
fn mul(self, rhs: f64) -> E { E::new(Expr::Mul(self, constant(rhs))).simplify() }
}
impl std::ops::Mul<E> for f64 {
type Output = E;
fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self), rhs)).simplify() }
}
impl std::ops::Div<f64> for E {
type Output = E;
fn div(self, rhs: f64) -> E { E::new(Expr::Div(self, constant(rhs))).simplify() }
}
impl std::ops::Div<E> for f64 {
type Output = E;
fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self), rhs)).simplify() }
}
impl std::ops::Add<i64> for E {
type Output = E;
fn add(self, rhs: i64) -> E { E::new(Expr::Add(self, constant(rhs as f64))).simplify() }
}
impl std::ops::Add<E> for i64 {
type Output = E;
fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self as f64), rhs)).simplify() }
}
impl std::ops::Sub<i64> for E {
type Output = E;
fn sub(self, rhs: i64) -> E { E::new(Expr::Sub(self, constant(rhs as f64))).simplify() }
}
impl std::ops::Sub<E> for i64 {
type Output = E;
fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self as f64), rhs)).simplify() }
}
impl std::ops::Mul<i64> for E {
type Output = E;
fn mul(self, rhs: i64) -> E { E::new(Expr::Mul(self, constant(rhs as f64))).simplify() }
}
impl std::ops::Mul<E> for i64 {
type Output = E;
fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self as f64), rhs)).simplify() }
}
impl std::ops::Div<i64> for E {
type Output = E;
fn div(self, rhs: i64) -> E { E::new(Expr::Div(self, constant(rhs as f64))).simplify() }
}
impl std::ops::Div<E> for i64 {
type Output = E;
fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self as f64), rhs)).simplify() }
}
pub(crate) fn expand_func(params: &[String], body: &E, args: &[E]) -> E {
let mut expanded = body.clone();
for (p, a) in params.iter().zip(args.iter()) {
expanded = expanded.subs(p, a);
}
expanded
}
/// Defines a one-argument symbolic function called `name`.
///
/// `body` is evaluated once on the placeholder symbol `__p0`. The returned
/// closure wraps each argument in an [`Expr::Func`] node carrying that body,
/// so calls display as `name(arg)` instead of being inlined.
pub fn simple_func1(name: &str, body: impl Fn(E) -> E) -> impl Fn(E) -> E + Clone {
    let fn_name = name.to_owned();
    let template = body(symbol("__p0"));
    move |arg: E| {
        let kind = FuncKind::Symbolic { body: template.clone() };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0")],
            kind,
            args: vec![arg],
        })
    }
}

/// Two-argument version of [`simple_func1`]; placeholders are `__p0`, `__p1`.
pub fn simple_func2(name: &str, body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> E + Clone {
    let fn_name = name.to_owned();
    let template = body(symbol("__p0"), symbol("__p1"));
    move |a: E, b: E| {
        let kind = FuncKind::Symbolic { body: template.clone() };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0"), String::from("__p1")],
            kind,
            args: vec![a, b],
        })
    }
}

/// `arity`-argument version of [`simple_func1`]; placeholders are
/// `__p0 .. __p{arity-1}`.
///
/// # Panics
/// The returned closure panics if called with a number of arguments other
/// than `arity`.
pub fn simple_func(name: &str, arity: usize, body: impl Fn(Vec<E>) -> E) -> impl Fn(Vec<E>) -> E + Clone {
    let fn_name = name.to_owned();
    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
    let template = body(params.iter().map(|p| symbol(p)).collect());
    move |args: Vec<E>| {
        assert_eq!(args.len(), arity,
            "custom function '{}' expects {} args, got {}", fn_name, arity, args.len());
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: params.clone(),
            kind: FuncKind::Symbolic { body: template.clone() },
            args,
        })
    }
}

/// Like [`simple_func1`], but with an explicit derivative (in terms of the
/// placeholder) instead of differentiating `body`.
pub fn simple_func1_derivs(
    name: &str, body: impl Fn(E) -> E, derivs: impl Fn(E) -> [E; 1],
) -> impl Fn(E) -> E + Clone {
    let fn_name = name.to_owned();
    let p0 = symbol("__p0");
    let template = body(p0.clone());
    let [d0] = derivs(p0);
    move |a: E| {
        let kind = FuncKind::SymbolicDerivs { body: template.clone(), derivs: vec![d0.clone()] };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0")],
            kind,
            args: vec![a],
        })
    }
}

/// Like [`simple_func2`], but with explicit partial derivatives.
pub fn simple_func2_derivs(
    name: &str, body: impl Fn(E, E) -> E, derivs: impl Fn(E, E) -> [E; 2],
) -> impl Fn(E, E) -> E + Clone {
    let fn_name = name.to_owned();
    let p0 = symbol("__p0");
    let p1 = symbol("__p1");
    let template = body(p0.clone(), p1.clone());
    let [d0, d1] = derivs(p0, p1);
    move |a: E, b: E| {
        let kind = FuncKind::SymbolicDerivs {
            body: template.clone(),
            derivs: vec![d0.clone(), d1.clone()],
        };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0"), String::from("__p1")],
            kind,
            args: vec![a, b],
        })
    }
}

/// Like [`simple_func`], but with explicit partial derivatives.
///
/// # Panics
/// Panics if `derivs` does not return exactly `arity` expressions. The
/// returned closure panics on an argument-count mismatch.
pub fn simple_func_derivs(
    name: &str, arity: usize, body: impl Fn(Vec<E>) -> E, derivs: impl Fn(Vec<E>) -> Vec<E>,
) -> impl Fn(Vec<E>) -> E + Clone {
    let fn_name = name.to_owned();
    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
    let placeholders: Vec<E> = params.iter().map(|p| symbol(p)).collect();
    let template = body(placeholders.clone());
    let partials = derivs(placeholders);
    assert_eq!(partials.len(), arity, "derivs must return {} elements", arity);
    move |args: Vec<E>| {
        assert_eq!(args.len(), arity,
            "function '{}' expects {} args, got {}", fn_name, arity, args.len());
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: params.clone(),
            kind: FuncKind::SymbolicDerivs { body: template.clone(), derivs: partials.clone() },
            args,
        })
    }
}
/// Defines a one-argument external function called `name`.
///
/// `eval_fn` computes numeric values, `call_path` is the path emitted in
/// generated Rust code, and `derivs` gives the derivative in terms of the
/// placeholder `__p0`.
pub fn extern_func1(
    name: &str, call_path: &str,
    derivs: impl Fn(E) -> [E; 1],
    eval_fn: fn(&[f64]) -> f64,
) -> impl Fn(E) -> E + Clone {
    let fn_name = name.to_owned();
    let path = call_path.to_owned();
    let [d0] = derivs(symbol("__p0"));
    move |a: E| {
        let kind = FuncKind::Extern {
            derivs: vec![d0.clone()],
            eval_fn,
            call_path: path.clone(),
        };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0")],
            kind,
            args: vec![a],
        })
    }
}

/// Two-argument version of [`extern_func1`]; placeholders are `__p0`, `__p1`.
pub fn extern_func2(
    name: &str, call_path: &str,
    derivs: impl Fn(E, E) -> [E; 2],
    eval_fn: fn(&[f64]) -> f64,
) -> impl Fn(E, E) -> E + Clone {
    let fn_name = name.to_owned();
    let path = call_path.to_owned();
    let [d0, d1] = derivs(symbol("__p0"), symbol("__p1"));
    move |a: E, b: E| {
        let kind = FuncKind::Extern {
            derivs: vec![d0.clone(), d1.clone()],
            eval_fn,
            call_path: path.clone(),
        };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: vec![String::from("__p0"), String::from("__p1")],
            kind,
            args: vec![a, b],
        })
    }
}

/// `arity`-argument version of [`extern_func1`].
///
/// # Panics
/// Panics if `derivs` does not return exactly `arity` expressions. The
/// returned closure panics on an argument-count mismatch.
pub fn extern_func(
    name: &str, arity: usize, call_path: &str,
    derivs: impl Fn(Vec<E>) -> Vec<E>,
    eval_fn: fn(&[f64]) -> f64,
) -> impl Fn(Vec<E>) -> E + Clone {
    let fn_name = name.to_owned();
    let path = call_path.to_owned();
    let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
    let partials = derivs(params.iter().map(|p| symbol(p)).collect());
    assert_eq!(partials.len(), arity, "derivs must return {} elements", arity);
    move |args: Vec<E>| {
        assert_eq!(args.len(), arity,
            "extern function '{}' expects {} args, got {}", fn_name, arity, args.len());
        let kind = FuncKind::Extern {
            derivs: partials.clone(),
            eval_fn,
            call_path: path.clone(),
        };
        E::new(Expr::Func {
            name: fn_name.clone(),
            params: params.clone(),
            kind,
            args,
        })
    }
}
/// Builds a derivative callback for `*_derivs` / `extern_func*` helpers by
/// differentiating `body` symbolically once, on the placeholder `__g0`.
pub fn grad1(body: impl Fn(E) -> E) -> impl Fn(E) -> [E; 1] + Clone {
    let derivative = body(symbol("__g0")).diff("__g0");
    move |a: E| [derivative.subs("__g0", &a)]
}

/// Two-argument counterpart of [`grad1`] (placeholders `__g0`, `__g1`).
///
/// NOTE(review): the placeholders are substituted one after the other, so an
/// argument `a` that itself contains the symbol `__g1` would be rewritten by
/// the second substitution.
pub fn grad2(body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> [E; 2] + Clone {
    let expr = body(symbol("__g0"), symbol("__g1"));
    let (d0, d1) = (expr.diff("__g0"), expr.diff("__g1"));
    move |a: E, b: E| {
        let at = |d: &E| d.subs("__g0", &a).subs("__g1", &b);
        [at(&d0), at(&d1)]
    }
}
/// Wraps an angle in radians into `[-PI, PI]`.
///
/// Values already inside the closed interval are returned unchanged. Any
/// other value is shifted by the multiple of `2*PI` that lands it in
/// `[-PI, PI)`. NaN stays NaN.
fn rad2rad(v: f64) -> f64 {
    use std::f64::consts::PI;
    if (-PI..=PI).contains(&v) {
        return v;
    }
    let full_turn = 2.0 * PI;
    let turns = (v / full_turn + 0.5).floor();
    v - full_turn * turns
}
pub fn rad_diff(a: E, b: E) -> E {
extern_func2("rad_diff", "arael::utils::rad_diff",
grad2(|a, b| a - b),
|args: &[f64]| rad2rad(args[0] - args[1]))(a, b)
}
pub fn rad_sum(a: E, b: E) -> E {
extern_func2("rad_sum", "arael::utils::rad_sum",
grad2(|a, b| a + b),
|args: &[f64]| rad2rad(args[0] + args[1]))(a, b)
}
pub fn identity(x: E) -> E {
simple_func1("identity", |t| t)(x)
}
pub fn safe_atan2(y: E, x: E) -> E {
simple_func2_derivs("safe_atan2",
atan2,
|y, x| {
let eps2 = epsilon() * epsilon();
let d = x.clone()*x.clone() + y.clone()*y.clone() + eps2;
[x / d.clone(), -y / d]
})(y, x)
}
pub fn safe_asin(x: E) -> E {
simple_func1_derivs("safe_asin",
|x| asin(clamp(x, c(-1.0), c(1.0))),
|x| {
let xc = clamp(x, c(-1.0), c(1.0));
[c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
}
)(x)
}
pub fn safe_acos(x: E) -> E {
simple_func1_derivs("safe_acos",
|x| acos(clamp(x, c(-1.0), c(1.0))),
|x| {
let xc = clamp(x, c(-1.0), c(1.0));
[-c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
}
)(x)
}
pub fn safe_sqrt(x: E) -> E {
extern_func1("safe_sqrt", "arael::utils::safe_sqrt",
|x| [c(0.5) / sqrt(identity(x.clone() * heaviside(x)) + epsilon()*epsilon())],
|args| {
let v = args[0];
if v <= 0.0 { 0.0 } else { v.sqrt() }
}
)(x)
}
pub use linalg::{SymVec, SymMat, jacobian};
pub use parse::{parse, parse_with_functions, ParseError};
pub use geo::{vect2sym, vect3sym, matrix2sym, matrix3sym, quaternsym};
pub use cse::cse;
pub use arael_sym_macros::sym;
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
#[test]
fn simple_func_identity_display() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
assert_eq!(format!("{}", identity(x)), "identity(x)");
}
}
#[test]
fn simple_func_identity_diff() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x);
assert_eq!(format!("{}", f.diff("x")), "1");
}
}
#[test]
fn simple_func_identity_chain_rule() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x * x);
assert_eq!(format!("{}", f.diff("x")), "2 * x");
}
}
#[test]
fn simple_func_identity_eval() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x);
let vars = HashMap::from([("x", 5.0)]);
assert_eq!(f.eval(&vars).unwrap(), 5.0);
}
}
#[test]
fn simple_func_square() {
sym! {
let square = simple_func1("square", |t| t * t);
let x = symbol("x");
let f = square(x + 1.0);
assert_eq!(format!("{}", f), "square(x + 1)");
assert_eq!(format!("{}", f.diff("x")), "2 * (x + 1)");
}
}
#[test]
fn simple_func_square_eval() {
sym! {
let square = simple_func1("square", |t| t * t);
let x = symbol("x");
let f = square(x);
let vars = HashMap::from([("x", 4.0)]);
assert_eq!(f.eval(&vars).unwrap(), 16.0);
}
}
#[test]
fn simple_func_binary() {
sym! {
let f = simple_func2("prod", |a, b| a * b);
let x = symbol("x");
let y = symbol("y");
let result = f(x, y);
assert_eq!(format!("{}", result), "prod(x, y)");
assert_eq!(format!("{}", result.diff("x")), "y");
assert_eq!(format!("{}", result.diff("y")), "x");
}
}
#[test]
fn simple_func_nested() {
sym! {
let identity = simple_func1("identity", |t| t);
let square = simple_func1("square", |t| t * t);
let x = symbol("x");
let f = identity(square(x));
assert_eq!(format!("{}", f), "identity(square(x))");
assert_eq!(format!("{}", f.diff("x")), "2 * x");
}
}
#[test]
fn simple_func_my_sin() {
sym! {
let my_sin = simple_func1("my_sin", |t| sin(t));
let x = symbol("x");
let f = my_sin(x);
assert_eq!(format!("{}", f), "my_sin(x)");
assert_eq!(format!("{}", f.diff("x")), "cos(x)");
}
}
#[test]
fn simple_func_my_sin_chain_rule() {
sym! {
let my_sin = simple_func1("my_sin", |t| sin(t));
let x = symbol("x");
let f = my_sin(x * x);
assert_eq!(format!("{}", f.diff("x")), "2 * x * cos(x^2)");
}
}
#[test]
fn simple_func_to_rust() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x);
assert_eq!(f.to_rust("f64"), "x");
}
}
#[test]
fn simple_func_latex() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x);
assert_eq!(f.to_latex(), "\\operatorname{identity}\\left(x\\right)");
}
}
#[test]
fn simple_func_free_vars() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x + symbol("y"));
let vars = f.free_vars();
assert!(vars.contains("x"));
assert!(vars.contains("y"));
assert!(!vars.contains("t"));
}
}
#[test]
fn simple_func_subs() {
sym! {
let identity = simple_func1("identity", |t| t);
let x = symbol("x");
let f = identity(x);
let g = f.subs("x", &constant(3.0));
assert_eq!(format!("{}", g), "identity(3)");
}
}
#[test]
fn simple_func_simplify_constants() {
sym! {
let square = simple_func1("square", |t| t * t);
let f = square(constant(3.0));
let s = f.simplify();
assert_eq!(format!("{}", s), "9");
}
}
#[test]
fn simple_func_nary() {
sym! {
let f = simple_func("triple_sum", 3, |v| v[0].clone() + v[1].clone() + v[2].clone());
let x = symbol("x");
let y = symbol("y");
let z = symbol("z");
let result = f(vec![x, y, z]);
assert_eq!(format!("{}", result), "triple_sum(x, y, z)");
assert_eq!(format!("{}", result.diff("x")), "1");
}
}
#[test]
fn simple_func_expand() {
sym! {
let square = simple_func1("square", |t| t * t);
let x = symbol("x");
let f = square(x + 1.0);
let expanded = f.expand();
assert_eq!(format!("{}", expanded), "x^2 + 2 * x + 1");
}
}
#[test]
fn simple_func_derivs_codegen() {
sym! {
let f = simple_func1_derivs("inv", |t| 1.0 / t, |t| [-1.0 / (t * t)]);
let x = symbol("x");
assert_eq!(f(x).to_rust("f64"), "1.0_f64 / x");
}
}
#[test]
fn safe_atan2_diff() {
sym! {
let a = symbol("a");
let b = symbol("b");
let f = safe_atan2(a, b);
let da = f.diff("a");
let vars = HashMap::from([("a", 1.0), ("b", 1.0)]);
let v = da.eval(&vars).unwrap();
assert!((v - 0.5).abs() < 1e-10, "d/da at (1,1) = {}, expected 0.5", v);
}
}
#[test]
fn safe_atan2_eval() {
sym! {
let a = symbol("a");
let b = symbol("b");
let f = safe_atan2(a, b);
let vars = HashMap::from([("a", 1.0), ("b", 1.0)]);
let v = f.eval(&vars).unwrap();
assert!((v - std::f64::consts::FRAC_PI_4).abs() < 1e-10);
}
}
#[test]
fn safe_atan2_chain_rule() {
sym! {
let t = symbol("t");
let f = safe_atan2(sin(t), cos(t));
let df = f.diff("t");
let vars = HashMap::from([("t", 0.5)]);
let v = df.eval(&vars).unwrap();
assert!((v - 1.0).abs() < 1e-8, "df/dt at t=0.5 = {}, expected 1", v);
}
}
#[test]
fn safe_atan2_at_zero() {
sym! {
let a = symbol("a");
let b = symbol("b");
let da = safe_atan2(a, b).diff("a");
let vars = HashMap::from([("a", 0.0), ("b", 0.0)]);
let v = da.eval(&vars).unwrap();
assert!(v.is_finite(), "derivative at (0,0) should be finite, got {}", v);
}
}
#[test]
fn safe_asin_eval() {
sym! {
let x = symbol("x");
let f = safe_asin(x);
let vars = HashMap::from([("x", 0.5)]);
assert!((f.eval(&vars).unwrap() - 0.5_f64.asin()).abs() < 1e-10);
let vars = HashMap::from([("x", 1.5)]);
assert!((f.eval(&vars).unwrap() - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
}
}
#[test]
fn safe_asin_deriv_finite() {
sym! {
let x = symbol("x");
let da = safe_asin(x).diff("x");
let vars = HashMap::from([("x", 1.0)]);
let v = da.eval(&vars).unwrap();
assert!(v.is_finite(), "safe_asin derivative at 1.0 should be finite, got {}", v);
}
}
#[test]
fn safe_acos_eval() {
sym! {
let x = symbol("x");
let f = safe_acos(x);
let vars = HashMap::from([("x", 0.5)]);
assert!((f.eval(&vars).unwrap() - 0.5_f64.acos()).abs() < 1e-10);
let vars = HashMap::from([("x", -1.5)]);
assert!((f.eval(&vars).unwrap() - std::f64::consts::PI).abs() < 1e-10);
}
}
#[test]
fn identity_codegen_parens() {
sym! {
let x = symbol("x");
let f = identity(c(1.0) - x * x) + epsilon * epsilon;
let code = f.to_rust("f64");
assert!(code.contains("(-x.powf(2.0_f64) + 1.0_f64)"),
"expected parens around identity body, got: {}", code);
}
}
#[test]
fn identity_diff() {
sym! {
let x = symbol("x");
let f = identity(x * x);
assert_eq!(format!("{}", f.diff("x")), "2 * x");
}
}
#[test]
fn safe_acos_deriv_finite() {
sym! {
let x = symbol("x");
let da = safe_acos(x).diff("x");
let vars = HashMap::from([("x", 1.0)]);
let v = da.eval(&vars).unwrap();
assert!(v.is_finite(), "safe_acos derivative at 1.0 should be finite, got {}", v);
}
}
#[test]
fn safe_derivs_finite_outside_domain() {
sym! {
let x = symbol("x");
let d_asin = safe_asin(x).diff("x");
let d_acos = safe_acos(x).diff("x");
let d_sqrt = safe_sqrt(x).diff("x");
for v in [-5.0_f64, -1.5, 1.5, 5.0] {
let vars = HashMap::from([("x", v)]);
let a = d_asin.eval(&vars).unwrap();
let c = d_acos.eval(&vars).unwrap();
assert!(a.is_finite(), "safe_asin'({}) should be finite, got {}", v, a);
assert!(c.is_finite(), "safe_acos'({}) should be finite, got {}", v, c);
}
for v in [-5.0_f64, -1.0, -1e-12, 0.0] {
let vars = HashMap::from([("x", v)]);
let s = d_sqrt.eval(&vars).unwrap();
assert!(s.is_finite(), "safe_sqrt'({}) should be finite, got {}", v, s);
}
}
}
#[test]
fn safe_sqrt_eval() {
sym! {
let x = symbol("x");
let f = safe_sqrt(x);
let vars = HashMap::from([("x", 4.0)]);
assert!((f.eval(&vars).unwrap() - 2.0).abs() < 1e-10);
let vars = HashMap::from([("x", -1e-10)]);
assert!(f.eval(&vars).unwrap().abs() < 1e-10);
let vars = HashMap::from([("x", 0.0)]);
assert!(f.eval(&vars).unwrap().abs() < 1e-10);
}
}
#[test]
fn safe_sqrt_deriv_at_zero() {
sym! {
let x = symbol("x");
let df = safe_sqrt(x).diff("x");
let vars = HashMap::from([("x", 0.0)]);
let v = df.eval(&vars).unwrap();
assert!(v.is_finite(), "safe_sqrt derivative at 0 should be finite, got {}", v);
}
}
#[test]
fn grad2_basic() {
sym! {
let g = grad2(|a, b| a * b);
let x = symbol("x");
let y = symbol("y");
let [da, db] = g(x, y);
assert_eq!(format!("{}", da), "y");
assert_eq!(format!("{}", db), "x");
}
}
#[test]
fn grad1_basic() {
sym! {
let g = grad1(|t| t * t);
let x = symbol("x");
let [dt] = g(x);
assert_eq!(format!("{}", dt), "2 * x");
}
}
#[test]
fn extern_func_display() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
assert_eq!(format!("{}", f), "rad_diff(x, y)");
}
}
#[test]
fn extern_func_diff() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
assert_eq!(format!("{}", f.diff("x")), "1");
assert_eq!(format!("{}", f.diff("y")), "-1");
}
}
#[test]
fn extern_func_chain_rule() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x * x, y);
assert_eq!(format!("{}", f.diff("x")), "2 * x");
}
}
#[test]
fn extern_func_eval() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
let vars = HashMap::from([("x", 0.3), ("y", 0.1)]);
let v = f.eval(&vars).unwrap();
assert!((v - 0.2).abs() < 1e-10);
}
}
/// `rad_diff` does not return the raw difference: rad_diff(0, 2*pi)
/// evaluates to ~0, i.e. 0 and 2*pi are treated as the same angle.
#[test]
fn extern_func_eval_wrapping() {
sym! {
let x = symbol("x");
let f = rad_diff(constant(0.0), x);
let vars = HashMap::from([("x", 2.0 * std::f64::consts::PI)]);
let v = f.eval(&vars).unwrap();
assert!(v.abs() < 1e-10, "rad_diff(0, 2*pi) = {}, expected 0", v);
}
}
/// Rust code generation for an extern function emits its registered Rust
/// path (`arael::utils::rad_diff`) applied to the arguments.
#[test]
fn extern_func_to_rust() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
let code = f.to_rust("f64");
assert_eq!(code, "arael::utils::rad_diff(x, y)");
}
}
/// LaTeX output wraps the extern function name in `\operatorname{...}`,
/// escapes its underscore, and uses `\left( ... \right)` around arguments.
#[test]
fn extern_func_latex() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
assert_eq!(f.to_latex(), "\\operatorname{rad\\_diff}\\left(x, y\\right)");
}
}
/// `subs` reaches inside extern function calls and replaces the symbol in
/// their arguments.
#[test]
fn extern_func_subs() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
let g = f.subs("x", &constant(1.0));
assert_eq!(format!("{}", g), "rad_diff(1, y)");
}
}
/// `simplify` leaves an extern function call unevaluated even when all of
/// its arguments are constants.
#[test]
fn extern_func_no_const_fold() {
sym! {
let f = rad_diff(constant(1.0), constant(2.0));
let s = f.simplify();
assert_eq!(format!("{}", s), "rad_diff(1, 2)");
}
}
/// `expand` leaves an extern function call, including its `x + 1`
/// argument, unchanged.
#[test]
fn extern_func_no_expand() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x + 1.0, y);
let expanded = f.expand();
assert_eq!(format!("{}", expanded), "rad_diff(x + 1, y)");
}
}
/// `free_vars` reports the symbols used in the call arguments (x, y) and
/// none of the `__a`/`__b` names. Those are presumably the extern function's
/// internal parameter placeholders (NOTE(review): confirm against the
/// extern function definition).
#[test]
fn extern_func_free_vars() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_diff(x, y);
let vars = f.free_vars();
assert!(vars.contains("x"));
assert!(vars.contains("y"));
assert!(!vars.contains("__a"));
assert!(!vars.contains("__b"));
}
}
/// Both partial derivatives of `rad_sum(x, y)` are 1.
#[test]
fn rad_sum_diff() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_sum(x, y);
assert_eq!(format!("{}", f.diff("x")), "1");
assert_eq!(format!("{}", f.diff("y")), "1");
}
}
/// Rust code generation for `rad_sum` emits the `arael::utils::rad_sum`
/// path.
#[test]
fn rad_sum_to_rust() {
sym! {
let x = symbol("x");
let y = symbol("y");
let f = rad_sum(x, y);
assert_eq!(f.to_rust("f64"), "arael::utils::rad_sum(x, y)");
}
}
/// Defines a custom two-argument extern function with `extern_func2`. The
/// arguments are the display name, the Rust path used by codegen, a
/// gradient built with `grad2`, and `my_eval` (from its signature,
/// presumably the numeric evaluator). The test checks display, both partial
/// derivatives and Rust codegen.
#[test]
fn extern_func_def() {
sym! {
fn my_eval(args: &[f64]) -> f64 { args[0] - args[1] }
let my_diff = extern_func2("my_diff", "my_mod::diff",
grad2(|a, b| a - b), my_eval);
let x = symbol("x");
let y = symbol("y");
let f = my_diff(x, y);
assert_eq!(format!("{}", f), "my_diff(x, y)");
assert_eq!(format!("{}", f.diff("x")), "1");
assert_eq!(format!("{}", f.diff("y")), "-1");
assert_eq!(f.to_rust("f64"), "my_mod::diff(x, y)");
}
}
/// Heaviside step evaluation: 0 for negative input and 1 for positive
/// input. The x = 0 case pins the convention H(0) = 1.
#[test]
fn heaviside_eval() {
sym! {
let x = symbol("x");
let h = heaviside(x);
// Each case builds its own binding inline, so the boundary case is as
// visible as the others.
assert_eq!(h.eval(&HashMap::from([("x", -1.0)])).unwrap(), 0.0, "H(-1) should be 0");
assert_eq!(h.eval(&HashMap::from([("x", 0.0)])).unwrap(), 1.0, "H(0) should be 1");
assert_eq!(h.eval(&HashMap::from([("x", 3.0)])).unwrap(), 1.0, "H(3) should be 1");
}
}
/// The symbolic derivative of H is reported as 0, both for H(x) and for a
/// composed argument H(x*x - 1).
#[test]
fn heaviside_diff() {
sym! {
let x = symbol("x");
assert_eq!(format!("{}", heaviside(x).diff("x")), "0");
assert_eq!(format!("{}", heaviside(x * x - 1.0).diff("x")), "0");
}
}
/// Heaviside displays as `H(arg)`.
#[test]
fn heaviside_display() {
sym! {
let x = symbol("x");
assert_eq!(format!("{}", heaviside(x)), "H(x)");
}
}
/// In a product, H behaves as a constant factor under differentiation:
/// d/dx [H(1 - x) * x^2] = 2x * H(1 - x). The display also shows the
/// simplifier's ordering of the argument as `-x + 1`.
#[test]
fn heaviside_composition_diff() {
sym! {
let x = symbol("x");
let f = heaviside(1.0 - x) * x * x;
assert_eq!(format!("{}", f.diff("x")), "2 * x * H(-x + 1)");
}
}
/// `clamp(x, 0, 1)` returns x inside the range and the nearest bound
/// outside it.
#[test]
fn clamp_eval() {
sym! {
let x = symbol("x");
let f = clamp(x, c(0.0), c(1.0));
assert_eq!(f.eval(&HashMap::from([("x", 0.5)])).unwrap(), 0.5);
assert_eq!(f.eval(&HashMap::from([("x", -2.0)])).unwrap(), 0.0);
assert_eq!(f.eval(&HashMap::from([("x", 5.0)])).unwrap(), 1.0);
}
}
/// The symbolic derivative of clamp is the derivative of its first
/// argument. The bounds do not appear in the result ("1" for x, "2 * x"
/// for x*x).
#[test]
fn clamp_diff_passthrough() {
sym! {
let x = symbol("x");
assert_eq!(format!("{}", clamp(x, c(0.0), c(1.0)).diff("x")), "1");
assert_eq!(format!("{}", clamp(x * x, c(0.0), c(1.0)).diff("x")), "2 * x");
}
}
/// Clamp displays as `clamp(value, lo, hi)`.
#[test]
fn clamp_display() {
sym! {
let x = symbol("x");
assert_eq!(format!("{}", clamp(x, c(0.0), c(1.0))), "clamp(x, 0, 1)");
}
}
/// `simplify` folds a clamp whose arguments are all constants: above,
/// below and inside the [0, 1] range.
#[test]
fn clamp_simplify_constants() {
sym! {
let f = clamp(c(5.0), c(0.0), c(1.0));
assert_eq!(format!("{}", f.simplify()), "1");
let g = clamp(c(-3.0), c(0.0), c(1.0));
assert_eq!(format!("{}", g.simplify()), "0");
let h = clamp(c(0.5), c(0.0), c(1.0));
assert_eq!(format!("{}", h.simplify()), "0.5");
}
}
/// A `simple_func1` wrapping asin(clamp(t, -1, 1)) matches `asin` inside
/// the domain. Outside it, the clamp saturates the result to +/- pi/2.
#[test]
fn clamp_asin_eval() {
sym! {
let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
let x = symbol("x");
let f = my_asin(x);
let val = f.eval(&HashMap::from([("x", 0.5)])).unwrap();
assert!((val - 0.5_f64.asin()).abs() < 1e-10);
// Inputs outside [-1, 1] hit the clamp bounds.
let val_hi = f.eval(&HashMap::from([("x", 1.5)])).unwrap();
assert!((val_hi - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
let val_lo = f.eval(&HashMap::from([("x", -1.5)])).unwrap();
assert!((val_lo + std::f64::consts::FRAC_PI_2).abs() < 1e-10);
}
}
/// Inside the clamp range, the derivative of asin(clamp(x, -1, 1)) equals
/// the plain asin derivative 1 / sqrt(1 - x^2); this test checks it at x = 0.5.
#[test]
fn clamp_asin_diff() {
sym! {
let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
let x = symbol("x");
let f = my_asin(x);
let df = f.diff("x");
let vars = HashMap::from([("x", 0.5)]);
let dval = df.eval(&vars).unwrap();
// 1 / sqrt(1 - 0.5^2)
let expected = 1.0 / (1.0 - 0.25_f64).sqrt();
assert!(
(dval - expected).abs() < 1e-10,
"my_asin'(0.5) = {}, expected {}",
dval,
expected
);
}
}
/// Rust code generation renders Heaviside as a `.heaviside()` method call
/// on its argument.
#[test]
fn heaviside_to_rust() {
sym! {
let x = symbol("x");
assert_eq!(heaviside(x).to_rust("f64"), "x.heaviside()");
}
}
/// Rust code generation renders clamp as `.clamp(lo, hi)`, with the
/// constant bounds suffixed by the requested float type.
#[test]
fn clamp_to_rust() {
sym! {
let x = symbol("x");
assert_eq!(clamp(x, c(0.0), c(1.0)).to_rust("f64"), "x.clamp(0.0_f64, 1.0_f64)");
}
}
/// The parser accepts `H(x)`. The result prints back as `H(x)` and
/// differentiates to 0.
#[test]
fn parse_heaviside() {
let step = parse("H(x)").unwrap();
assert_eq!(step.to_string(), "H(x)");
let d_step = step.diff("x");
assert_eq!(d_step.to_string(), "0");
}
/// The parser accepts `clamp(x, 0, 1)`. The result prints back unchanged
/// and differentiates to 1.
#[test]
fn parse_clamp() {
let clamped = parse("clamp(x, 0, 1)").unwrap();
assert_eq!(clamped.to_string(), "clamp(x, 0, 1)");
let d_clamped = clamped.diff("x");
assert_eq!(d_clamped.to_string(), "1");
}
/// The built-in `pi` constant displays by name.
#[test]
fn named_const_pi_display() {
let shown = pi().to_string();
assert_eq!(shown, "pi");
}
/// `pi` evaluates to exactly `std::f64::consts::PI` with no variables bound.
#[test]
fn named_const_pi_eval() {
let no_vars = HashMap::new();
let value = pi().eval(&no_vars).unwrap();
assert_eq!(value, std::f64::consts::PI);
}
/// `pi` is a constant, so its derivative with respect to any symbol is 0.
#[test]
fn named_const_pi_diff() {
let derivative = pi().diff("x");
assert_eq!(derivative.to_string(), "0");
}
/// `pi` generates the `consts::PI` path of the requested float type.
#[test]
fn named_const_pi_codegen() {
for (float_ty, expected) in [("f64", "std::f64::consts::PI"), ("f32", "std::f32::consts::PI")] {
assert_eq!(pi().to_rust(float_ty), expected);
}
}
/// `pi` renders as the LaTeX command `\pi`.
#[test]
fn named_const_pi_latex() {
let rendered = pi().to_latex();
assert_eq!(rendered, "\\pi");
}
/// The built-in `epsilon` constant displays by name.
#[test]
fn named_const_epsilon_display() {
let shown = epsilon().to_string();
assert_eq!(shown, "epsilon");
}
/// `epsilon` evaluates to exactly `f64::EPSILON` with no variables bound.
#[test]
fn named_const_epsilon_eval() {
let no_vars = HashMap::new();
let value = epsilon().eval(&no_vars).unwrap();
assert_eq!(value, f64::EPSILON);
}
/// `epsilon` generates the `EPSILON` associated constant of the requested
/// float type.
#[test]
fn named_const_epsilon_codegen() {
for (float_ty, expected) in [("f64", "f64::EPSILON"), ("f32", "f32::EPSILON")] {
assert_eq!(epsilon().to_rust(float_ty), expected);
}
}
/// Euler's number displays as `e`.
#[test]
fn named_const_euler_display() {
let shown = euler().to_string();
assert_eq!(shown, "e");
}
/// `euler` evaluates to exactly `std::f64::consts::E` with no variables
/// bound.
#[test]
fn named_const_euler_eval() {
let no_vars = HashMap::new();
let value = euler().eval(&no_vars).unwrap();
assert_eq!(value, std::f64::consts::E);
}
/// `euler` generates `std::f64::consts::E` for f64 code.
#[test]
fn named_const_euler_codegen() {
let code = euler().to_rust("f64");
assert_eq!(code, "std::f64::consts::E");
}
/// `simplify` keeps `epsilon` symbolic instead of replacing it by its
/// numeric value, so `x + epsilon` prints unchanged.
#[test]
fn named_const_epsilon_survives_simplification() {
sym! {
let x = symbol("x");
let f = (x + epsilon()).simplify();
assert_eq!(format!("{}", f), "x + epsilon");
}
}
/// Named constants are not reported by `free_vars`: only `x` appears for
/// `x + pi`.
#[test]
fn named_const_not_free_var() {
sym! {
let x = symbol("x");
let f = x + pi();
let vars = f.free_vars();
assert!(vars.contains("x"));
assert!(!vars.contains("pi"));
}
}
/// A user-defined named constant (`tau`) supports display, evaluation, Rust
/// code generation for both float widths, and LaTeX output.
#[test]
fn named_const_custom() {
let tau = named_const("tau", std::f64::consts::TAU,
"std::f32::consts::TAU", "std::f64::consts::TAU", "\\tau");
assert_eq!(format!("{}", tau), "tau");
let vars = HashMap::new();
assert_eq!(tau.eval(&vars).unwrap(), std::f64::consts::TAU);
assert_eq!(tau.to_rust("f64"), "std::f64::consts::TAU");
// The f32 codegen string is passed to the constructor as well; check it
// too (as the built-in pi test does), so a constructor that drops or
// misroutes it is caught.
assert_eq!(tau.to_rust("f32"), "std::f32::consts::TAU");
assert_eq!(tau.to_latex(), "\\tau");
}
/// `simplify` collects like named-constant terms: pi + pi becomes 2 * pi.
#[test]
fn named_const_pi_add_pi() {
sym! {
let f = (pi() + pi()).simplify();
assert_eq!(format!("{}", f), "2 * pi");
}
}
/// `simplify` cancels a named constant against itself: pi - pi becomes 0.
#[test]
fn named_const_pi_sub_pi() {
sym! {
let f = (pi() - pi()).simplify();
assert_eq!(format!("{}", f), "0");
}
}
/// `simplify` combines repeated named-constant factors into a power:
/// pi * pi becomes pi^2.
#[test]
fn named_const_pi_mul_pi() {
sym! {
let f = (pi() * pi()).simplify();
assert_eq!(format!("{}", f), "pi^2");
}
}
/// Like terms of a named constant are collected even next to a symbol:
/// x + eps + eps becomes x + 2 * epsilon.
#[test]
fn named_const_epsilon_add() {
sym! {
let x = symbol("x");
let f = (x + epsilon() + epsilon()).simplify();
assert_eq!(format!("{}", f), "x + 2 * epsilon");
}
}
/// `simplify` evaluates sin(pi) exactly to 0.
#[test]
fn trig_sin_pi() {
sym! { assert_eq!(format!("{}", sin(pi()).simplify()), "0"); }
}
/// `simplify` evaluates cos(pi) exactly to -1.
#[test]
fn trig_cos_pi() {
sym! { assert_eq!(format!("{}", cos(pi()).simplify()), "-1"); }
}
/// `simplify` evaluates sin(pi / 2) exactly to 1.
#[test]
fn trig_sin_pi_half() {
sym! { assert_eq!(format!("{}", sin(pi() / 2.0).simplify()), "1"); }
}
/// `simplify` evaluates cos(pi / 2) exactly to 0, not a tiny float
/// residue.
#[test]
fn trig_cos_pi_half() {
sym! { assert_eq!(format!("{}", cos(pi() / 2.0).simplify()), "0"); }
}
/// sin(pi / 4) simplifies to something that evaluates to 1/sqrt(2). It is
/// checked numerically, so no particular symbolic form is required.
#[test]
fn trig_sin_pi_quarter() {
sym! {
let f = sin(pi() / 4.0).simplify();
let vars = HashMap::new();
let v = f.eval(&vars).unwrap();
assert!((v - std::f64::consts::FRAC_1_SQRT_2).abs() < 1e-10);
}
}
/// `simplify` evaluates cos(pi / 3) to exactly 0.5. The display check
/// would fail on a float residue such as 0.5000000000000001.
#[test]
fn trig_cos_pi_third() {
sym! {
let f = cos(pi() / 3.0).simplify();
assert_eq!(format!("{}", f), "0.5");
}
}
/// `simplify` evaluates sin(2 * pi) exactly to 0.
#[test]
fn trig_sin_2pi() {
sym! { assert_eq!(format!("{}", sin(2.0 * pi()).simplify()), "0"); }
}
/// `simplify` evaluates cos(2 * pi) exactly to 1.
#[test]
fn trig_cos_2pi() {
sym! { assert_eq!(format!("{}", cos(2.0 * pi()).simplify()), "1"); }
}
/// `simplify` evaluates tan(pi) exactly to 0.
#[test]
fn trig_tan_pi() {
sym! { assert_eq!(format!("{}", tan(pi()).simplify()), "0"); }
}
/// `simplify` evaluates sin(pi / 6) exactly to 0.5.
#[test]
fn trig_sin_pi_sixth() {
sym! { assert_eq!(format!("{}", sin(pi() / 6.0).simplify()), "0.5"); }
}
/// `simplify` evaluates ln(e) exactly to 1.
#[test]
fn ln_e() {
sym! { assert_eq!(format!("{}", ln(euler()).simplify()), "1"); }
}
/// Inside `sym!`, a bare `pi` identifier (no call parentheses) is accepted
/// as the pi constant. The display shows the constructor's term ordering.
#[test]
fn sym_macro_bare_pi() {
sym! {
let x = symbol("x");
let f = 2.0 * pi * x;
assert_eq!(format!("{}", f), "2 * x * pi");
}
}
/// Inside `sym!`, a bare `epsilon` identifier is accepted as the epsilon
/// constant.
#[test]
fn sym_macro_bare_epsilon() {
sym! {
let x = symbol("x");
let f = x * x + epsilon;
assert_eq!(format!("{}", f), "x^2 + epsilon");
}
}
/// The explicit `pi()` call form still works inside `sym!` alongside the
/// bare identifier form.
#[test]
fn sym_macro_pi_call_still_works() {
sym! {
let f = pi();
assert_eq!(format!("{}", f), "pi");
}
}
/// `simplify` cancels ln against a power of e: ln(e^x) becomes x.
#[test]
fn ln_e_pow_x() {
sym! {
let x = symbol("x");
let f = ln(pow(euler(), x)).simplify();
assert_eq!(format!("{}", f), "x");
}
}
}