#![allow(clippy::should_implement_trait)]
mod diff;
mod eval;
mod fmt;
mod simplify;
mod linalg;
mod parse;
pub mod geo;
pub mod cse;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
/// Shared, immutable handle to a node of an expression tree.
///
/// The node lives behind an `Rc`, so cloning an `E` only bumps a reference
/// count; the underlying [`Expr`] is never copied. Equality is the derived
/// comparison, which compares the pointed-to [`Expr`] values structurally.
#[derive(PartialEq, Clone)]
pub struct E(Rc<Expr>);

// NOTE(review): `Expr::Const` holds an `f64`, and NaN != NaN under `==`, so
// this `Eq` impl assumes NaN constants are not compared as distinct values.
impl Eq for E {}
impl E {
fn new(expr: Expr) -> E {
E(Rc::new(expr))
}
pub fn symbols(&self) -> std::collections::HashSet<String> {
let mut out = std::collections::HashSet::new();
self.collect_symbols(&mut out);
out
}
fn collect_symbols(&self, out: &mut std::collections::HashSet<String>) {
match &*self.0 {
Expr::Sym(s) => { out.insert(s.clone()); }
Expr::Const(_) => {}
Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
| Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
| Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
| Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
| Expr::Sqrt(a) | Expr::Abs(a) => { a.collect_symbols(out); }
Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
| Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
a.collect_symbols(out);
b.collect_symbols(out);
}
}
}
pub fn substitute(&self, subs: &[(E, E)]) -> E {
for (from, to) in subs {
if self == from { return to.clone(); }
}
match &*self.0 {
Expr::Sym(_) | Expr::Const(_) => self.clone(),
Expr::Neg(a) => -a.substitute(subs),
Expr::Add(a, b) => a.substitute(subs) + b.substitute(subs),
Expr::Sub(a, b) => a.substitute(subs) - b.substitute(subs),
Expr::Mul(a, b) => a.substitute(subs) * b.substitute(subs),
Expr::Div(a, b) => a.substitute(subs) / b.substitute(subs),
Expr::Pow(a, b) => pow(a.substitute(subs), b.substitute(subs)),
Expr::Sin(a) => sin(a.substitute(subs)),
Expr::Cos(a) => cos(a.substitute(subs)),
Expr::Tan(a) => tan(a.substitute(subs)),
Expr::Asin(a) => asin(a.substitute(subs)),
Expr::Acos(a) => acos(a.substitute(subs)),
Expr::Atan(a) => atan(a.substitute(subs)),
Expr::Atan2(a, b) => atan2(a.substitute(subs), b.substitute(subs)),
Expr::Sinh(a) => sinh(a.substitute(subs)),
Expr::Cosh(a) => cosh(a.substitute(subs)),
Expr::Tanh(a) => tanh(a.substitute(subs)),
Expr::Exp(a) => exp(a.substitute(subs)),
Expr::Ln(a) => ln(a.substitute(subs)),
Expr::Log2(a) => log2(a.substitute(subs)),
Expr::Log10(a) => ln(a.substitute(subs)) / ln(constant(10.0)),
Expr::Sqrt(a) => sqrt(a.substitute(subs)),
Expr::Abs(a) => abs(a.substitute(subs)),
}
}
}
/// Exposes the wrapped [`Expr`] node, so `&*e` (or auto-deref) yields `&Expr`.
impl std::ops::Deref for E {
    type Target = Expr;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
/// Borrows the wrapped [`Expr`] node; the same reference `Deref` returns.
impl AsRef<Expr> for E {
    fn as_ref(&self) -> &Expr {
        // `&E` coerces to `&Expr` through the `Deref` impl.
        self
    }
}
/// A single node of a symbolic expression tree.
///
/// Children are held as [`E`] handles. The derived `PartialEq` compares trees
/// structurally, with `Const` payloads compared by `f64` `==`.
///
/// NOTE(review): variant order determines `mem::discriminant`, which the `Hash`
/// impl mixes in, so reordering variants changes hash values.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Named symbol.
Sym(String),
/// Numeric constant.
Const(f64),
/// Negation `-a` (built by the unary `-` operator on `E`).
Neg(E),
/// Sum `a + b`.
Add(E, E),
/// Difference `a - b`.
Sub(E, E),
/// Product `a * b`.
Mul(E, E),
/// Quotient `a / b`.
Div(E, E),
/// Power; fields are `(base, exponent)` as built by [`pow`].
Pow(E, E),
/// `sin(a)`.
Sin(E),
/// `cos(a)`.
Cos(E),
/// `tan(a)`.
Tan(E),
/// `asin(a)`.
Asin(E),
/// `acos(a)`.
Acos(E),
/// `atan(a)`.
Atan(E),
/// Two-argument arctangent; fields are `(y, x)` as built by [`atan2`].
Atan2(E, E),
/// `sinh(a)`.
Sinh(E),
/// `cosh(a)`.
Cosh(E),
/// `tanh(a)`.
Tanh(E),
/// `exp(a)`.
Exp(E),
/// `ln(a)`.
Ln(E),
/// `log2(a)`.
Log2(E),
/// `log10(a)`.
Log10(E),
/// `sqrt(a)`.
Sqrt(E),
/// `abs(a)`.
Abs(E),
}
// NOTE(review): `Expr::Const` wraps an `f64`, and the derived `PartialEq`
// compares it with IEEE `==`, under which NaN != NaN. This `Eq` impl therefore
// assumes NaN constants are not relied upon to compare equal to themselves.
impl Eq for Expr {}

/// Structural hash, consistent with the derived `PartialEq`.
///
/// The variant discriminant is hashed first, then the payload: the symbol
/// name, the constant's (canonicalized) bit pattern, or each child in order.
impl Hash for Expr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        match self {
            Expr::Sym(s) => s.hash(state),
            Expr::Const(v) => {
                // `PartialEq` compares constants with `==`, where `0.0 == -0.0`
                // even though their bit patterns differ. Fold `-0.0` onto `0.0`
                // so that equal expressions always hash equally, as the
                // `Hash`/`Eq` contract requires.
                let canonical = if *v == 0.0 { 0.0_f64 } else { *v };
                canonical.to_bits().hash(state)
            }
            Expr::Neg(a)
            | Expr::Sin(a)
            | Expr::Cos(a)
            | Expr::Tan(a)
            | Expr::Asin(a)
            | Expr::Acos(a)
            | Expr::Atan(a)
            | Expr::Sinh(a)
            | Expr::Cosh(a)
            | Expr::Tanh(a)
            | Expr::Exp(a)
            | Expr::Ln(a)
            | Expr::Log2(a)
            | Expr::Log10(a)
            | Expr::Sqrt(a)
            | Expr::Abs(a) => a.hash(state),
            Expr::Add(a, b)
            | Expr::Sub(a, b)
            | Expr::Mul(a, b)
            | Expr::Div(a, b)
            | Expr::Pow(a, b)
            | Expr::Atan2(a, b) => {
                a.hash(state);
                b.hash(state);
            }
        }
    }
}

/// Hashes the pointed-to [`Expr`] (by value, not by pointer), matching the
/// derived structural `PartialEq` on `E`.
impl Hash for E {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
/// Creates a symbol node with the given name.
pub fn symbol(name: &str) -> E {
    let owned: String = name.to_owned();
    E::new(Expr::Sym(owned))
}

/// Creates a numeric constant node holding `val`.
pub fn constant(val: f64) -> E {
    let node = Expr::Const(val);
    E::new(node)
}

/// Short alias for [`constant`].
pub fn c(val: f64) -> E {
    constant(val)
}
// Function-node constructors. All of them except `pow` wrap their argument(s)
// in the matching `Expr` variant without any simplification.

/// Builds `sin(e)`.
pub fn sin(e: E) -> E {
    E::new(Expr::Sin(e))
}

/// Builds `cos(e)`.
pub fn cos(e: E) -> E {
    E::new(Expr::Cos(e))
}

/// Builds `tan(e)`.
pub fn tan(e: E) -> E {
    E::new(Expr::Tan(e))
}

/// Builds `asin(e)`.
pub fn asin(e: E) -> E {
    E::new(Expr::Asin(e))
}

/// Builds `acos(e)`.
pub fn acos(e: E) -> E {
    E::new(Expr::Acos(e))
}

/// Builds `atan(e)`.
pub fn atan(e: E) -> E {
    E::new(Expr::Atan(e))
}

/// Builds `atan2(y, x)`; the node stores `y` first, then `x`.
pub fn atan2(y: E, x: E) -> E {
    E::new(Expr::Atan2(y, x))
}

/// Builds `sinh(e)`.
pub fn sinh(e: E) -> E {
    E::new(Expr::Sinh(e))
}

/// Builds `cosh(e)`.
pub fn cosh(e: E) -> E {
    E::new(Expr::Cosh(e))
}

/// Builds `tanh(e)`.
pub fn tanh(e: E) -> E {
    E::new(Expr::Tanh(e))
}

/// Builds `exp(e)`.
pub fn exp(e: E) -> E {
    E::new(Expr::Exp(e))
}

/// Builds `ln(e)`.
pub fn ln(e: E) -> E {
    E::new(Expr::Ln(e))
}

/// Builds `log2(e)`.
pub fn log2(e: E) -> E {
    E::new(Expr::Log2(e))
}

/// Builds `log10(e)`.
pub fn log10(e: E) -> E {
    E::new(Expr::Log10(e))
}

/// Builds `sqrt(e)`.
pub fn sqrt(e: E) -> E {
    E::new(Expr::Sqrt(e))
}

/// Builds `abs(e)`.
pub fn abs(e: E) -> E {
    E::new(Expr::Abs(e))
}

/// Builds `base ^ exponent` and immediately runs `simplify()` on the new node.
pub fn pow(base: E, exponent: E) -> E {
    let raw = E::new(Expr::Pow(base, exponent));
    raw.simplify()
}
impl std::ops::Add for E {
type Output = E;
fn add(self, rhs: E) -> E {
E::new(Expr::Add(self, rhs)).simplify()
}
}
impl std::ops::Sub for E {
type Output = E;
fn sub(self, rhs: E) -> E {
E::new(Expr::Sub(self, rhs)).simplify()
}
}
impl std::ops::Mul for E {
type Output = E;
fn mul(self, rhs: E) -> E {
E::new(Expr::Mul(self, rhs)).simplify()
}
}
impl std::ops::Div for E {
type Output = E;
fn div(self, rhs: E) -> E {
E::new(Expr::Div(self, rhs)).simplify()
}
}
impl std::ops::Neg for E {
type Output = E;
fn neg(self) -> E {
E::new(Expr::Neg(self)).simplify()
}
}
// Mixed `E`/`f64` arithmetic. The `f64` operand is lifted with `constant` and
// the work is forwarded to the matching `E`-with-`E` operator, which builds the
// node (operands kept in their original order) and simplifies it.

impl std::ops::Add<f64> for E {
    type Output = E;

    fn add(self, rhs: f64) -> E {
        self + constant(rhs)
    }
}

impl std::ops::Add<E> for f64 {
    type Output = E;

    fn add(self, rhs: E) -> E {
        constant(self) + rhs
    }
}

impl std::ops::Sub<f64> for E {
    type Output = E;

    fn sub(self, rhs: f64) -> E {
        self - constant(rhs)
    }
}

impl std::ops::Sub<E> for f64 {
    type Output = E;

    fn sub(self, rhs: E) -> E {
        constant(self) - rhs
    }
}

impl std::ops::Mul<f64> for E {
    type Output = E;

    fn mul(self, rhs: f64) -> E {
        self * constant(rhs)
    }
}

impl std::ops::Mul<E> for f64 {
    type Output = E;

    fn mul(self, rhs: E) -> E {
        constant(self) * rhs
    }
}

impl std::ops::Div<f64> for E {
    type Output = E;

    fn div(self, rhs: f64) -> E {
        self / constant(rhs)
    }
}

impl std::ops::Div<E> for f64 {
    type Output = E;

    fn div(self, rhs: E) -> E {
        constant(self) / rhs
    }
}
pub use linalg::{SymVec, SymMat, jacobian};
pub use diff::DiffVar;
pub use parse::{parse, ParseError};
pub use geo::{vect2sym, vect3sym, matrix2sym, matrix3sym, quaternsym};
pub use cse::cse;
pub use arael_sym_macros::sym;