1#![allow(clippy::should_implement_trait)]
294
295mod diff;
296mod eval;
297mod fmt;
298mod simplify;
299mod linalg;
300mod parse;
301pub mod geo;
302pub mod cse;
303
304use std::hash::{Hash, Hasher};
305use std::rc::Rc;
306
/// Shared, immutable handle to an expression node.
///
/// Cloning an `E` clones the inner `Rc`, so subtrees are shared rather than
/// deep-copied. The derived `PartialEq` compares the referenced `Expr` values
/// (structural equality), not the handles' addresses.
#[derive(Clone, PartialEq)]
pub struct E(Rc<Expr>);

// NOTE(review): `Expr` can hold `f64` constants and NaN is not equal to
// itself under `PartialEq`, so this `Eq` is not guaranteed to be reflexive
// for expressions containing NaN constants.
impl Eq for E {}
316
317impl E {
318 fn new(expr: Expr) -> E {
319 E(Rc::new(expr))
320 }
321
322 pub fn symbols(&self) -> std::collections::HashSet<String> {
324 let mut out = std::collections::HashSet::new();
325 self.collect_symbols(&mut out);
326 out
327 }
328
329 fn collect_symbols(&self, out: &mut std::collections::HashSet<String>) {
330 match &*self.0 {
331 Expr::Sym(s) => { out.insert(s.clone()); }
332 Expr::Const(_) | Expr::NamedConst { .. } => {}
333 Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
334 | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
335 | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
336 | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
337 | Expr::Sqrt(a) | Expr::Abs(a)
338 | Expr::Heaviside(a) => { a.collect_symbols(out); }
339 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
340 | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
341 a.collect_symbols(out);
342 b.collect_symbols(out);
343 }
344 Expr::Clamp(a, b, c) => {
345 a.collect_symbols(out);
346 b.collect_symbols(out);
347 c.collect_symbols(out);
348 }
349 Expr::Func { args, .. } => {
350 for arg in args { arg.collect_symbols(out); }
351 }
352 }
353 }
354
355 pub fn substitute(&self, subs: &[(E, E)]) -> E {
358 for (from, to) in subs {
359 if self == from { return to.clone(); }
360 }
361 match &*self.0 {
362 Expr::Sym(_) | Expr::Const(_) | Expr::NamedConst { .. } => self.clone(),
363 Expr::Neg(a) => -a.substitute(subs),
364 Expr::Add(a, b) => a.substitute(subs) + b.substitute(subs),
365 Expr::Sub(a, b) => a.substitute(subs) - b.substitute(subs),
366 Expr::Mul(a, b) => a.substitute(subs) * b.substitute(subs),
367 Expr::Div(a, b) => a.substitute(subs) / b.substitute(subs),
368 Expr::Pow(a, b) => pow(a.substitute(subs), b.substitute(subs)),
369 Expr::Sin(a) => sin(a.substitute(subs)),
370 Expr::Cos(a) => cos(a.substitute(subs)),
371 Expr::Tan(a) => tan(a.substitute(subs)),
372 Expr::Asin(a) => asin(a.substitute(subs)),
373 Expr::Acos(a) => acos(a.substitute(subs)),
374 Expr::Atan(a) => atan(a.substitute(subs)),
375 Expr::Atan2(a, b) => atan2(a.substitute(subs), b.substitute(subs)),
376 Expr::Sinh(a) => sinh(a.substitute(subs)),
377 Expr::Cosh(a) => cosh(a.substitute(subs)),
378 Expr::Tanh(a) => tanh(a.substitute(subs)),
379 Expr::Exp(a) => exp(a.substitute(subs)),
380 Expr::Ln(a) => ln(a.substitute(subs)),
381 Expr::Log2(a) => log2(a.substitute(subs)),
382 Expr::Log10(a) => ln(a.substitute(subs)) / ln(constant(10.0)),
383 Expr::Sqrt(a) => sqrt(a.substitute(subs)),
384 Expr::Abs(a) => abs(a.substitute(subs)),
385 Expr::Heaviside(a) => heaviside(a.substitute(subs)),
386 Expr::Clamp(a, lo, hi) => clamp(a.substitute(subs), lo.substitute(subs), hi.substitute(subs)),
387 Expr::Func { name, params, kind, args } => {
388 let new_args = args.iter().map(|a| a.substitute(subs)).collect();
389 E::new(Expr::Func { name: name.clone(), params: params.clone(), kind: kind.clone(), args: new_args })
390 }
391 }
392 }
393}
394
395impl std::ops::Deref for E {
396 type Target = Expr;
397 fn deref(&self) -> &Expr {
398 &self.0
399 }
400}
401
402impl AsRef<Expr> for E {
403 fn as_ref(&self) -> &Expr {
404 &self.0
405 }
406}
407
/// One node of a symbolic expression tree.
///
/// Children are held as [`E`] handles, so subtrees are shared between
/// expressions instead of copied.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A variable, identified by name.
    Sym(String),
    /// A numeric literal.
    Const(f64),
    /// Negation, `-a`.
    Neg(E),
    /// `a + b`.
    Add(E, E),
    /// `a - b`.
    Sub(E, E),
    /// `a * b`.
    Mul(E, E),
    /// `a / b`.
    Div(E, E),
    /// `base ^ exponent`.
    Pow(E, E),
    /// Sine.
    Sin(E),
    /// Cosine.
    Cos(E),
    /// Tangent.
    Tan(E),
    /// Inverse sine.
    Asin(E),
    /// Inverse cosine.
    Acos(E),
    /// Inverse tangent.
    Atan(E),
    /// Two-argument arctangent, stored in the `(y, x)` order taken by [`atan2`].
    Atan2(E, E),
    /// Hyperbolic sine.
    Sinh(E),
    /// Hyperbolic cosine.
    Cosh(E),
    /// Hyperbolic tangent.
    Tanh(E),
    /// Exponential.
    Exp(E),
    /// Natural logarithm.
    Ln(E),
    /// Base-2 logarithm.
    Log2(E),
    /// Base-10 logarithm.
    Log10(E),
    /// Square root.
    Sqrt(E),
    /// Absolute value.
    Abs(E),
    /// Heaviside step of the argument (the value at 0 is decided by the
    /// evaluator, which is not part of this file).
    Heaviside(E),
    /// `clamp(value, lo, hi)`, in the argument order of [`clamp`].
    Clamp(E, E, E),
    /// A named mathematical constant (see [`pi`], [`epsilon`], [`euler`]).
    NamedConst {
        /// Display name, e.g. `"pi"`.
        name: String,
        /// Numeric value of the constant.
        value: f64,
        /// Rust spelling of the value for `f32` code, e.g. `"std::f32::consts::PI"`.
        rust_f32: String,
        /// Rust spelling of the value for `f64` code, e.g. `"std::f64::consts::PI"`.
        rust_f64: String,
        /// LaTeX spelling, e.g. `"\\pi"`.
        latex: String,
    },
    /// A call of a user-defined function.
    Func {
        /// Function name.
        name: String,
        /// Placeholder parameter names (e.g. `__p0`, `__p1`, ...) that the
        /// expressions stored in `kind` are written in terms of.
        params: Vec<String>,
        /// How the function is defined (symbolic body, explicit derivatives,
        /// or an external native function).
        kind: FuncKind,
        /// Actual arguments; the builders in this file create one per parameter.
        args: Vec<E>,
    },
}
487
/// How a user-defined function (`Expr::Func`) is defined.
#[derive(Debug, Clone, PartialEq)]
// The derived `PartialEq` compares `Extern`'s `eval_fn` function pointers,
// which rustc lints about; the lint is silenced for this type.
#[allow(unpredictable_function_pointer_comparisons)]
pub enum FuncKind {
    /// Only a body over the parameter placeholders; derivatives have to be
    /// obtained from the body (`auto_diff_body` returns it).
    Symbolic { body: E },
    /// A body plus explicitly supplied partial derivatives, one per parameter
    /// (built by the `simple_func*_derivs` helpers).
    SymbolicDerivs { body: E, derivs: Vec<E> },
    /// No symbolic body: one partial derivative per parameter, a native
    /// function presumably used for numeric evaluation, and `call_path`, the
    /// Rust path of the equivalent function (e.g. `arael::utils::rad_diff`).
    Extern { derivs: Vec<E>, eval_fn: fn(&[f64]) -> f64, call_path: String },
}
501
502impl FuncKind {
503 pub fn auto_diff_body(&self) -> Option<&E> {
505 match self {
506 FuncKind::Symbolic { body } => Some(body),
507 _ => None,
508 }
509 }
510
511 pub fn derivs(&self) -> Option<&[E]> {
513 match self {
514 FuncKind::SymbolicDerivs { derivs, .. } | FuncKind::Extern { derivs, .. } => Some(derivs),
515 FuncKind::Symbolic { .. } => None,
516 }
517 }
518
519 pub fn body(&self) -> Option<&E> {
521 match self {
522 FuncKind::Symbolic { body } | FuncKind::SymbolicDerivs { body, .. } => Some(body),
523 FuncKind::Extern { .. } => None,
524 }
525 }
526
527 pub fn eval_fn(&self) -> Option<fn(&[f64]) -> f64> {
529 match self {
530 FuncKind::Extern { eval_fn, .. } => Some(*eval_fn),
531 _ => None,
532 }
533 }
534}
535
536impl Hash for FuncKind {
537 fn hash<H: Hasher>(&self, state: &mut H) {
538 std::mem::discriminant(self).hash(state);
539 match self {
540 FuncKind::Symbolic { body } => body.hash(state),
541 FuncKind::SymbolicDerivs { body, derivs } => {
542 body.hash(state);
543 derivs.hash(state);
544 }
545 FuncKind::Extern { derivs, eval_fn, call_path } => {
546 derivs.hash(state);
547 (*eval_fn as usize).hash(state);
548 call_path.hash(state);
549 }
550 }
551 }
552}
553
554impl Eq for Expr {}
555
556impl Hash for Expr {
557 fn hash<H: Hasher>(&self, state: &mut H) {
558 std::mem::discriminant(self).hash(state);
559 match self {
560 Expr::Sym(s) => s.hash(state),
561 Expr::Const(v) => v.to_bits().hash(state),
562 Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a)
563 | Expr::Asin(a) | Expr::Acos(a) | Expr::Atan(a)
564 | Expr::Sinh(a) | Expr::Cosh(a) | Expr::Tanh(a)
565 | Expr::Exp(a) | Expr::Ln(a) | Expr::Log2(a) | Expr::Log10(a)
566 | Expr::Sqrt(a) | Expr::Abs(a)
567 | Expr::Heaviside(a) => a.hash(state),
568 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b)
569 | Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
570 a.hash(state);
571 b.hash(state);
572 }
573 Expr::Clamp(a, b, c) => {
574 a.hash(state);
575 b.hash(state);
576 c.hash(state);
577 }
578 Expr::NamedConst { name, value, .. } => {
579 name.hash(state);
580 value.to_bits().hash(state);
581 }
582 Expr::Func { name, params, kind, args } => {
583 name.hash(state);
584 params.hash(state);
585 kind.hash(state);
586 args.hash(state);
587 }
588 }
589 }
590}
591
592impl Hash for E {
593 fn hash<H: Hasher>(&self, state: &mut H) {
594 self.0.hash(state);
595 }
596}
597
598pub fn symbol(name: &str) -> E {
602 E::new(Expr::Sym(name.to_string()))
603}
604
605pub trait AsVarName {
614 fn var_name(&self) -> &str;
616
617 fn var_expr(&self) -> E {
620 symbol(self.var_name())
621 }
622}
623
624impl AsVarName for &str {
625 fn var_name(&self) -> &str { self }
626}
627
628impl AsVarName for &&str {
629 fn var_name(&self) -> &str { self }
630}
631
632impl AsVarName for str {
633 fn var_name(&self) -> &str { self }
634}
635
636impl AsVarName for String {
637 fn var_name(&self) -> &str { self.as_str() }
638}
639
640impl AsVarName for &String {
641 fn var_name(&self) -> &str { self.as_str() }
642}
643
644impl AsVarName for &E {
645 fn var_name(&self) -> &str { (*self).var_name() }
646 fn var_expr(&self) -> E { (*self).clone() }
647}
648
649impl AsVarName for E {
650 fn var_name(&self) -> &str {
651 match self.as_ref() {
652 Expr::Sym(name) => name.as_str(),
653 _ => panic!("AsVarName::var_name: expected a symbol, got `{self}`"),
654 }
655 }
656 fn var_expr(&self) -> E { self.clone() }
657}
658
/// Creates one symbol per identifier and returns them as a tuple, e.g.
/// `let (x, y) = symbols!(x, y);`. A single name yields a one-element tuple.
#[macro_export]
macro_rules! symbols {
    ($($name:ident),+ $(,)?) => {
        ( $( $crate::symbol(stringify!($name)), )+ )
    };
}
679
680pub fn constant(val: f64) -> E {
682 E::new(Expr::Const(val))
683}
684
685impl From<f64> for E {
686 fn from(v: f64) -> E { constant(v) }
687}
688
689impl From<i64> for E {
690 fn from(v: i64) -> E { constant(v as f64) }
691}
692
693impl From<i32> for E {
694 fn from(v: i32) -> E { constant(v as f64) }
695}
696
697pub fn named_const(name: &str, value: f64, rust_f32: &str, rust_f64: &str, latex: &str) -> E {
699 E::new(Expr::NamedConst {
700 name: name.to_string(), value,
701 rust_f32: rust_f32.to_string(), rust_f64: rust_f64.to_string(),
702 latex: latex.to_string(),
703 })
704}
705
706pub fn pi() -> E {
708 named_const("pi", std::f64::consts::PI,
709 "std::f32::consts::PI", "std::f64::consts::PI", "\\pi")
710}
711
712pub fn epsilon() -> E {
714 named_const("epsilon", f64::EPSILON,
715 "f32::EPSILON", "f64::EPSILON", "\\epsilon")
716}
717
718pub fn euler() -> E {
720 named_const("e", std::f64::consts::E,
721 "std::f32::consts::E", "std::f64::consts::E", "e")
722}
723
724pub fn c(val: f64) -> E { constant(val) }
726
727pub fn sin(e: E) -> E { E::new(Expr::Sin(e)) }
729pub fn cos(e: E) -> E { E::new(Expr::Cos(e)) }
731pub fn tan(e: E) -> E { E::new(Expr::Tan(e)) }
733pub fn asin(e: E) -> E { E::new(Expr::Asin(e)) }
735pub fn acos(e: E) -> E { E::new(Expr::Acos(e)) }
737pub fn atan(e: E) -> E { E::new(Expr::Atan(e)) }
739pub fn atan2(y: E, x: E) -> E { E::new(Expr::Atan2(y, x)) }
741pub fn sinh(e: E) -> E { E::new(Expr::Sinh(e)) }
743pub fn cosh(e: E) -> E { E::new(Expr::Cosh(e)) }
745pub fn tanh(e: E) -> E { E::new(Expr::Tanh(e)) }
747pub fn exp(e: E) -> E { E::new(Expr::Exp(e)) }
749pub fn ln(e: E) -> E { E::new(Expr::Ln(e)) }
751pub fn log2(e: E) -> E { E::new(Expr::Log2(e)) }
753pub fn log10(e: E) -> E { E::new(Expr::Log10(e)) }
755pub fn sqrt(e: E) -> E { E::new(Expr::Sqrt(e)) }
757pub fn abs(e: E) -> E { E::new(Expr::Abs(e)) }
759pub fn heaviside(e: E) -> E { E::new(Expr::Heaviside(e)) }
761pub fn clamp(val: impl Into<E>, lo: impl Into<E>, hi: impl Into<E>) -> E {
765 E::new(Expr::Clamp(val.into(), lo.into(), hi.into()))
766}
767pub fn pow(base: impl Into<E>, exponent: impl Into<E>) -> E {
771 E::new(Expr::Pow(base.into(), exponent.into())).simplify()
772}
773
/// A built-in function constructor, tagged by its arity.
#[derive(Clone, Copy)]
pub enum FunctionRef {
    /// One-argument constructor, e.g. [`sin`].
    Unary(fn(E) -> E),
    /// Two-argument constructor, e.g. [`atan2`] or [`pow`].
    Binary(fn(E, E) -> E),
    /// Three-argument constructor, e.g. [`clamp`].
    Ternary(fn(E, E, E) -> E),
}
793
/// Built-in functions, keyed by the name they are looked up under.
///
/// [`function_by_name`] returns the first entry whose name matches, and
/// [`function_names`] yields the names in this order.
pub const FUNCTIONS: &[(&str, FunctionRef)] = &[
    // Elementary one-argument functions.
    ("sin", FunctionRef::Unary(sin)),
    ("cos", FunctionRef::Unary(cos)),
    ("tan", FunctionRef::Unary(tan)),
    ("asin", FunctionRef::Unary(asin)),
    ("acos", FunctionRef::Unary(acos)),
    ("atan", FunctionRef::Unary(atan)),
    ("sinh", FunctionRef::Unary(sinh)),
    ("cosh", FunctionRef::Unary(cosh)),
    ("tanh", FunctionRef::Unary(tanh)),
    ("exp", FunctionRef::Unary(exp)),
    // Logarithms, roots and step-like functions.
    ("ln", FunctionRef::Unary(ln)),
    ("log2", FunctionRef::Unary(log2)),
    ("log10", FunctionRef::Unary(log10)),
    ("sqrt", FunctionRef::Unary(sqrt)),
    ("abs", FunctionRef::Unary(abs)),
    ("heaviside", FunctionRef::Unary(heaviside)),
    ("identity", FunctionRef::Unary(identity)),
    // Guarded variants: clamped domains and eps-regularized derivatives.
    ("safe_sqrt", FunctionRef::Unary(safe_sqrt)),
    ("safe_asin", FunctionRef::Unary(safe_asin)),
    ("safe_acos", FunctionRef::Unary(safe_acos)),
    // Two-argument functions.
    ("atan2", FunctionRef::Binary(atan2)),
    ("pow", FunctionRef::Binary(pow)),
    ("safe_atan2", FunctionRef::Binary(safe_atan2)),
    ("rad_diff", FunctionRef::Binary(rad_diff)),
    ("rad_sum", FunctionRef::Binary(rad_sum)),
    // Three-argument functions.
    ("clamp", FunctionRef::Ternary(clamp)),
];
831
832pub fn function_by_name(name: &str) -> Option<FunctionRef> {
836 FUNCTIONS.iter().find(|(n, _)| *n == name).map(|(_, f)| *f)
837}
838
839pub fn function_names() -> impl Iterator<Item = &'static str> {
842 FUNCTIONS.iter().map(|(n, _)| *n)
843}
844
/// A named collection of user-defined function templates.
///
/// Each entry stores a function's parameter names and definition;
/// [`FunctionBag::call`] instantiates an entry as an `Expr::Func` node with
/// concrete arguments.
#[derive(Clone)]
pub struct FunctionBag {
    // Function name -> template. Registering an existing name replaces it
    // (all insertions go through `HashMap::insert`).
    table: std::collections::HashMap<String, BagFunction>,
}

/// A stored template: an `Expr::Func` without its call arguments.
#[derive(Clone)]
struct BagFunction {
    params: std::vec::Vec<String>,
    kind: FuncKind,
}
918
919impl Default for FunctionBag {
920 fn default() -> Self { Self::new() }
921}
922
923fn extract_func_template(e: E, source: &str) -> Result<(String, std::vec::Vec<String>, FuncKind), String> {
924 match (*e.0).clone() {
925 Expr::Func { name, params, kind, .. } => Ok((name, params, kind)),
926 _ => Err(format!("{source}: expected Expr::Func, got a different expression")),
927 }
928}
929
930impl FunctionBag {
931 pub fn new() -> Self {
934 Self { table: std::collections::HashMap::new() }
935 }
936
937 pub fn add(&mut self, e: E) -> Result<(), String> {
948 let (name, params, kind) = extract_func_template(e, "FunctionBag::add")?;
949 self.table.insert(name, BagFunction { params, kind });
950 Ok(())
951 }
952
953 pub fn add1<F>(&mut self, f: F) -> Result<(), String>
960 where F: FnOnce(E) -> E
961 {
962 let e = f(symbol("__a0"));
963 let (name, params, kind) = extract_func_template(e, "FunctionBag::add1")?;
964 self.table.insert(name, BagFunction { params, kind });
965 Ok(())
966 }
967
968 pub fn add2<F>(&mut self, f: F) -> Result<(), String>
975 where F: FnOnce(E, E) -> E
976 {
977 let e = f(symbol("__a0"), symbol("__a1"));
978 let (name, params, kind) = extract_func_template(e, "FunctionBag::add2")?;
979 self.table.insert(name, BagFunction { params, kind });
980 Ok(())
981 }
982
983 #[allow(non_snake_case)]
995 pub fn addN<F>(&mut self, arity: usize, f: F) -> Result<(), String>
996 where F: FnOnce(std::vec::Vec<E>) -> E
997 {
998 let placeholders: std::vec::Vec<E> =
999 (0..arity).map(|i| symbol(&format!("__a{i}"))).collect();
1000 let e = f(placeholders);
1001 let (name, params, kind) = extract_func_template(e, "FunctionBag::addN")?;
1002 self.table.insert(name, BagFunction { params, kind });
1003 Ok(())
1004 }
1005
1006 pub fn add_symbolic(&mut self, name: impl Into<String>, params: std::vec::Vec<String>, body: E) {
1011 self.table.insert(
1012 name.into(),
1013 BagFunction { params, kind: FuncKind::Symbolic { body } },
1014 );
1015 }
1016
1017 pub fn add_with_kind(
1022 &mut self,
1023 name: impl Into<String>,
1024 params: std::vec::Vec<String>,
1025 kind: FuncKind,
1026 ) {
1027 self.table.insert(name.into(), BagFunction { params, kind });
1028 }
1029
1030 pub fn remove(&mut self, name: &str) -> bool {
1033 self.table.remove(name).is_some()
1034 }
1035
1036 pub fn contains(&self, name: &str) -> bool {
1039 self.table.contains_key(name)
1040 }
1041
1042 pub fn names(&self) -> std::vec::Vec<String> {
1044 self.table.keys().cloned().collect()
1045 }
1046
1047 pub fn entries(&self) -> impl Iterator<Item = (&str, usize)> {
1050 self.table.iter().map(|(k, v)| (k.as_str(), v.params.len()))
1051 }
1052
1053 pub fn get_info(&self, name: &str) -> Option<(&[String], &FuncKind)> {
1057 let f = self.table.get(name)?;
1058 Some((&f.params, &f.kind))
1059 }
1060
1061 pub fn call(&self, name: &str, args: &[E]) -> Option<Result<E, String>> {
1069 let f = self.table.get(name)?;
1070 if args.len() != f.params.len() {
1071 return Some(Err(format!(
1072 "{} expects {} argument(s), got {}",
1073 name, f.params.len(), args.len()
1074 )));
1075 }
1076 let func = E::new(Expr::Func {
1077 name: name.to_string(),
1078 params: f.params.clone(),
1079 kind: f.kind.clone(),
1080 args: args.to_vec(),
1081 });
1082 Some(Ok(func))
1083 }
1084}
1085
1086impl std::ops::Add for E {
1089 type Output = E;
1090 fn add(self, rhs: E) -> E {
1091 E::new(Expr::Add(self, rhs)).simplify()
1092 }
1093}
1094
1095impl std::ops::Sub for E {
1096 type Output = E;
1097 fn sub(self, rhs: E) -> E {
1098 E::new(Expr::Sub(self, rhs)).simplify()
1099 }
1100}
1101
1102impl std::ops::Mul for E {
1103 type Output = E;
1104 fn mul(self, rhs: E) -> E {
1105 E::new(Expr::Mul(self, rhs)).simplify()
1106 }
1107}
1108
1109impl std::ops::Div for E {
1110 type Output = E;
1111 fn div(self, rhs: E) -> E {
1112 E::new(Expr::Div(self, rhs)).simplify()
1113 }
1114}
1115
1116impl std::ops::Neg for E {
1117 type Output = E;
1118 fn neg(self) -> E {
1119 E::new(Expr::Neg(self)).simplify()
1120 }
1121}
1122
1123impl std::ops::Add<f64> for E {
1126 type Output = E;
1127 fn add(self, rhs: f64) -> E { E::new(Expr::Add(self, constant(rhs))).simplify() }
1128}
1129
1130impl std::ops::Add<E> for f64 {
1131 type Output = E;
1132 fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self), rhs)).simplify() }
1133}
1134
1135impl std::ops::Sub<f64> for E {
1136 type Output = E;
1137 fn sub(self, rhs: f64) -> E { E::new(Expr::Sub(self, constant(rhs))).simplify() }
1138}
1139
1140impl std::ops::Sub<E> for f64 {
1141 type Output = E;
1142 fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self), rhs)).simplify() }
1143}
1144
1145impl std::ops::Mul<f64> for E {
1146 type Output = E;
1147 fn mul(self, rhs: f64) -> E { E::new(Expr::Mul(self, constant(rhs))).simplify() }
1148}
1149
1150impl std::ops::Mul<E> for f64 {
1151 type Output = E;
1152 fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self), rhs)).simplify() }
1153}
1154
1155impl std::ops::Div<f64> for E {
1156 type Output = E;
1157 fn div(self, rhs: f64) -> E { E::new(Expr::Div(self, constant(rhs))).simplify() }
1158}
1159
1160impl std::ops::Div<E> for f64 {
1161 type Output = E;
1162 fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self), rhs)).simplify() }
1163}
1164
1165impl std::ops::Add<i64> for E {
1176 type Output = E;
1177 fn add(self, rhs: i64) -> E { E::new(Expr::Add(self, constant(rhs as f64))).simplify() }
1178}
1179
1180impl std::ops::Add<E> for i64 {
1181 type Output = E;
1182 fn add(self, rhs: E) -> E { E::new(Expr::Add(constant(self as f64), rhs)).simplify() }
1183}
1184
1185impl std::ops::Sub<i64> for E {
1186 type Output = E;
1187 fn sub(self, rhs: i64) -> E { E::new(Expr::Sub(self, constant(rhs as f64))).simplify() }
1188}
1189
1190impl std::ops::Sub<E> for i64 {
1191 type Output = E;
1192 fn sub(self, rhs: E) -> E { E::new(Expr::Sub(constant(self as f64), rhs)).simplify() }
1193}
1194
1195impl std::ops::Mul<i64> for E {
1196 type Output = E;
1197 fn mul(self, rhs: i64) -> E { E::new(Expr::Mul(self, constant(rhs as f64))).simplify() }
1198}
1199
1200impl std::ops::Mul<E> for i64 {
1201 type Output = E;
1202 fn mul(self, rhs: E) -> E { E::new(Expr::Mul(constant(self as f64), rhs)).simplify() }
1203}
1204
1205impl std::ops::Div<i64> for E {
1206 type Output = E;
1207 fn div(self, rhs: i64) -> E { E::new(Expr::Div(self, constant(rhs as f64))).simplify() }
1208}
1209
1210impl std::ops::Div<E> for i64 {
1211 type Output = E;
1212 fn div(self, rhs: E) -> E { E::new(Expr::Div(constant(self as f64), rhs)).simplify() }
1213}
1214
1215pub(crate) fn expand_func(params: &[String], body: &E, args: &[E]) -> E {
1219 let mut expanded = body.clone();
1220 for (p, a) in params.iter().zip(args.iter()) {
1221 expanded = expanded.subs(p, a);
1222 }
1223 expanded
1224}
1225
1226pub fn simple_func1(name: &str, body: impl Fn(E) -> E) -> impl Fn(E) -> E + Clone {
1240 let name = name.to_string();
1241 let body = body(symbol("__p0"));
1242 move |arg: E| {
1243 E::new(Expr::Func {
1244 name: name.clone(),
1245 params: vec!["__p0".to_string()],
1246 kind: FuncKind::Symbolic { body: body.clone() },
1247 args: vec![arg],
1248 })
1249 }
1250}
1251
1252pub fn simple_func2(name: &str, body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> E + Clone {
1255 let name = name.to_string();
1256 let body = body(symbol("__p0"), symbol("__p1"));
1257 move |a: E, b: E| {
1258 E::new(Expr::Func {
1259 name: name.clone(),
1260 params: vec!["__p0".to_string(), "__p1".to_string()],
1261 kind: FuncKind::Symbolic { body: body.clone() },
1262 args: vec![a, b],
1263 })
1264 }
1265}
1266
1267pub fn simple_func(name: &str, arity: usize, body: impl Fn(Vec<E>) -> E) -> impl Fn(Vec<E>) -> E + Clone {
1270 let name = name.to_string();
1271 let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1272 let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1273 let body = body(syms);
1274 move |args: Vec<E>| {
1275 assert_eq!(args.len(), arity,
1276 "custom function '{}' expects {} args, got {}", name, arity, args.len());
1277 E::new(Expr::Func {
1278 name: name.clone(),
1279 params: params.clone(),
1280 kind: FuncKind::Symbolic { body: body.clone() },
1281 args,
1282 })
1283 }
1284}
1285
1286pub fn simple_func1_derivs(
1289 name: &str, body: impl Fn(E) -> E, derivs: impl Fn(E) -> [E; 1],
1290) -> impl Fn(E) -> E + Clone {
1291 let name = name.to_string();
1292 let p0 = symbol("__p0");
1293 let body = body(p0.clone());
1294 let d = derivs(p0);
1295 move |a: E| {
1296 E::new(Expr::Func {
1297 name: name.clone(),
1298 params: vec!["__p0".to_string()],
1299 kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: vec![d[0].clone()] },
1300 args: vec![a],
1301 })
1302 }
1303}
1304
1305pub fn simple_func2_derivs(
1319 name: &str, body: impl Fn(E, E) -> E, derivs: impl Fn(E, E) -> [E; 2],
1320) -> impl Fn(E, E) -> E + Clone {
1321 let name = name.to_string();
1322 let p0 = symbol("__p0");
1323 let p1 = symbol("__p1");
1324 let body = body(p0.clone(), p1.clone());
1325 let d = derivs(p0, p1);
1326 move |a: E, b: E| {
1327 E::new(Expr::Func {
1328 name: name.clone(),
1329 params: vec!["__p0".to_string(), "__p1".to_string()],
1330 kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: vec![d[0].clone(), d[1].clone()] },
1331 args: vec![a, b],
1332 })
1333 }
1334}
1335
1336pub fn simple_func_derivs(
1339 name: &str, arity: usize, body: impl Fn(Vec<E>) -> E, derivs: impl Fn(Vec<E>) -> Vec<E>,
1340) -> impl Fn(Vec<E>) -> E + Clone {
1341 let name = name.to_string();
1342 let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1343 let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1344 let body = body(syms.clone());
1345 let d = derivs(syms);
1346 assert_eq!(d.len(), arity, "derivs must return {} elements", arity);
1347 move |args: Vec<E>| {
1348 assert_eq!(args.len(), arity,
1349 "function '{}' expects {} args, got {}", name, arity, args.len());
1350 E::new(Expr::Func {
1351 name: name.clone(),
1352 params: params.clone(),
1353 kind: FuncKind::SymbolicDerivs { body: body.clone(), derivs: d.clone() },
1354 args,
1355 })
1356 }
1357}
1358
1359pub fn extern_func1(
1362 name: &str, call_path: &str,
1363 derivs: impl Fn(E) -> [E; 1],
1364 eval_fn: fn(&[f64]) -> f64,
1365) -> impl Fn(E) -> E + Clone {
1366 let name = name.to_string();
1367 let call_path = call_path.to_string();
1368 let d = derivs(symbol("__p0"));
1369 move |a: E| {
1370 E::new(Expr::Func {
1371 name: name.clone(),
1372 params: vec!["__p0".to_string()],
1373 kind: FuncKind::Extern {
1374 derivs: vec![d[0].clone()],
1375 eval_fn,
1376 call_path: call_path.clone(),
1377 },
1378 args: vec![a],
1379 })
1380 }
1381}
1382
1383pub fn extern_func2(
1401 name: &str, call_path: &str,
1402 derivs: impl Fn(E, E) -> [E; 2],
1403 eval_fn: fn(&[f64]) -> f64,
1404) -> impl Fn(E, E) -> E + Clone {
1405 let name = name.to_string();
1406 let call_path = call_path.to_string();
1407 let d = derivs(symbol("__p0"), symbol("__p1"));
1408 move |a: E, b: E| {
1409 E::new(Expr::Func {
1410 name: name.clone(),
1411 params: vec!["__p0".to_string(), "__p1".to_string()],
1412 kind: FuncKind::Extern {
1413 derivs: vec![d[0].clone(), d[1].clone()],
1414 eval_fn,
1415 call_path: call_path.clone(),
1416 },
1417 args: vec![a, b],
1418 })
1419 }
1420}
1421
1422pub fn extern_func(
1425 name: &str, arity: usize, call_path: &str,
1426 derivs: impl Fn(Vec<E>) -> Vec<E>,
1427 eval_fn: fn(&[f64]) -> f64,
1428) -> impl Fn(Vec<E>) -> E + Clone {
1429 let name = name.to_string();
1430 let call_path = call_path.to_string();
1431 let params: Vec<String> = (0..arity).map(|i| format!("__p{}", i)).collect();
1432 let syms: Vec<E> = params.iter().map(|p| symbol(p)).collect();
1433 let d = derivs(syms);
1434 assert_eq!(d.len(), arity, "derivs must return {} elements", arity);
1435 move |args: Vec<E>| {
1436 assert_eq!(args.len(), arity,
1437 "extern function '{}' expects {} args, got {}", name, arity, args.len());
1438 E::new(Expr::Func {
1439 name: name.clone(),
1440 params: params.clone(),
1441 kind: FuncKind::Extern {
1442 derivs: d.clone(),
1443 eval_fn,
1444 call_path: call_path.clone(),
1445 },
1446 args,
1447 })
1448 }
1449}
1450
1451pub fn grad1(body: impl Fn(E) -> E) -> impl Fn(E) -> [E; 1] + Clone {
1454 let p = symbol("__g0");
1455 let d = body(p).diff("__g0");
1456 move |a: E| { [d.subs("__g0", &a)] }
1457}
1458
1459pub fn grad2(body: impl Fn(E, E) -> E) -> impl Fn(E, E) -> [E; 2] + Clone {
1462 let p0 = symbol("__g0");
1463 let p1 = symbol("__g1");
1464 let expr = body(p0, p1);
1465 let d0 = expr.diff("__g0");
1466 let d1 = expr.diff("__g1");
1467 move |a: E, b: E| {
1468 [d0.subs("__g0", &a).subs("__g1", &b),
1469 d1.subs("__g0", &a).subs("__g1", &b)]
1470 }
1471}
1472
/// Wraps an angle in radians into `[-π, π]`.
///
/// Values already in `[-π, π]` are returned unchanged. Anything else is
/// shifted by the whole number of turns that brings it into `[-π, π)`.
fn rad2rad(v: f64) -> f64 {
    use std::f64::consts::{PI, TAU};
    if (-PI..=PI).contains(&v) {
        return v;
    }
    // TAU is exactly 2.0 * PI in f64, so this matches the `2π` formula bit-for-bit.
    let turns = (v / TAU + 0.5).floor();
    v - TAU * turns
}
1482
1483pub fn rad_diff(a: E, b: E) -> E {
1487 extern_func2("rad_diff", "arael::utils::rad_diff",
1488 grad2(|a, b| a - b),
1489 |args: &[f64]| rad2rad(args[0] - args[1]))(a, b)
1490}
1491
1492pub fn rad_sum(a: E, b: E) -> E {
1496 extern_func2("rad_sum", "arael::utils::rad_sum",
1497 grad2(|a, b| a + b),
1498 |args: &[f64]| rad2rad(args[0] + args[1]))(a, b)
1499}
1500
1501pub fn identity(x: E) -> E {
1511 simple_func1("identity", |t| t)(x)
1512}
1513
1514pub fn safe_atan2(y: E, x: E) -> E {
1523 simple_func2_derivs("safe_atan2",
1524 atan2,
1525 |y, x| {
1526 let eps2 = epsilon() * epsilon();
1527 let d = x.clone()*x.clone() + y.clone()*y.clone() + eps2;
1528 [x / d.clone(), -y / d]
1529 })(y, x)
1530}
1531
1532pub fn safe_asin(x: E) -> E {
1541 simple_func1_derivs("safe_asin",
1542 |x| asin(clamp(x, c(-1.0), c(1.0))),
1543 |x| {
1548 let xc = clamp(x, c(-1.0), c(1.0));
1549 [c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
1550 }
1551 )(x)
1552}
1553
1554pub fn safe_acos(x: E) -> E {
1560 simple_func1_derivs("safe_acos",
1561 |x| acos(clamp(x, c(-1.0), c(1.0))),
1562 |x| {
1565 let xc = clamp(x, c(-1.0), c(1.0));
1566 [-c(1.0) / sqrt(identity(c(1.0) - xc.clone()*xc) + epsilon()*epsilon())]
1567 }
1568 )(x)
1569}
1570
1571pub fn safe_sqrt(x: E) -> E {
1581 extern_func1("safe_sqrt", "arael::utils::safe_sqrt",
1582 |x| [c(0.5) / sqrt(identity(x.clone() * heaviside(x)) + epsilon()*epsilon())],
1587 |args| {
1588 let v = args[0];
1589 if v <= 0.0 { 0.0 } else { v.sqrt() }
1590 }
1591 )(x)
1592}
1593
1594pub use linalg::{SymVec, SymMat, jacobian};
1596pub use parse::{parse, parse_with_functions, ParseError};
1597pub use geo::{vect2sym, vect3sym, matrix2sym, matrix3sym, quaternsym};
1598pub use cse::cse;
1599pub use arael_sym_macros::sym;
1600
1601#[cfg(test)]
1602mod tests {
1603 use super::*;
1604 use std::collections::HashMap;
1605
1606 #[test]
1607 fn simple_func_identity_display() {
1608 sym! {
1609 let identity = simple_func1("identity", |t| t);
1610 let x = symbol("x");
1611 assert_eq!(format!("{}", identity(x)), "identity(x)");
1612 }
1613 }
1614
1615 #[test]
1616 fn simple_func_identity_diff() {
1617 sym! {
1618 let identity = simple_func1("identity", |t| t);
1619 let x = symbol("x");
1620 let f = identity(x);
1621 assert_eq!(format!("{}", f.diff("x")), "1");
1622 }
1623 }
1624
1625 #[test]
1626 fn simple_func_identity_chain_rule() {
1627 sym! {
1628 let identity = simple_func1("identity", |t| t);
1629 let x = symbol("x");
1630 let f = identity(x * x);
1631 assert_eq!(format!("{}", f.diff("x")), "2 * x");
1632 }
1633 }
1634
1635 #[test]
1636 fn simple_func_identity_eval() {
1637 sym! {
1638 let identity = simple_func1("identity", |t| t);
1639 let x = symbol("x");
1640 let f = identity(x);
1641 let vars = HashMap::from([("x", 5.0)]);
1642 assert_eq!(f.eval(&vars).unwrap(), 5.0);
1643 }
1644 }
1645
1646 #[test]
1647 fn simple_func_square() {
1648 sym! {
1649 let square = simple_func1("square", |t| t * t);
1650 let x = symbol("x");
1651 let f = square(x + 1.0);
1652 assert_eq!(format!("{}", f), "square(x + 1)");
1653 assert_eq!(format!("{}", f.diff("x")), "2 * (x + 1)");
1654 }
1655 }
1656
1657 #[test]
1658 fn simple_func_square_eval() {
1659 sym! {
1660 let square = simple_func1("square", |t| t * t);
1661 let x = symbol("x");
1662 let f = square(x);
1663 let vars = HashMap::from([("x", 4.0)]);
1664 assert_eq!(f.eval(&vars).unwrap(), 16.0);
1665 }
1666 }
1667
1668 #[test]
1669 fn simple_func_binary() {
1670 sym! {
1671 let f = simple_func2("prod", |a, b| a * b);
1672 let x = symbol("x");
1673 let y = symbol("y");
1674 let result = f(x, y);
1675 assert_eq!(format!("{}", result), "prod(x, y)");
1676 assert_eq!(format!("{}", result.diff("x")), "y");
1677 assert_eq!(format!("{}", result.diff("y")), "x");
1678 }
1679 }
1680
1681 #[test]
1682 fn simple_func_nested() {
1683 sym! {
1684 let identity = simple_func1("identity", |t| t);
1685 let square = simple_func1("square", |t| t * t);
1686 let x = symbol("x");
1687 let f = identity(square(x));
1688 assert_eq!(format!("{}", f), "identity(square(x))");
1689 assert_eq!(format!("{}", f.diff("x")), "2 * x");
1690 }
1691 }
1692
1693 #[test]
1694 fn simple_func_my_sin() {
1695 sym! {
1696 let my_sin = simple_func1("my_sin", |t| sin(t));
1697 let x = symbol("x");
1698 let f = my_sin(x);
1699 assert_eq!(format!("{}", f), "my_sin(x)");
1700 assert_eq!(format!("{}", f.diff("x")), "cos(x)");
1701 }
1702 }
1703
1704 #[test]
1705 fn simple_func_my_sin_chain_rule() {
1706 sym! {
1707 let my_sin = simple_func1("my_sin", |t| sin(t));
1708 let x = symbol("x");
1709 let f = my_sin(x * x);
1710 assert_eq!(format!("{}", f.diff("x")), "2 * x * cos(x^2)");
1711 }
1712 }
1713
1714 #[test]
1715 fn simple_func_to_rust() {
1716 sym! {
1717 let identity = simple_func1("identity", |t| t);
1718 let x = symbol("x");
1719 let f = identity(x);
1720 assert_eq!(f.to_rust("f64"), "x");
1721 }
1722 }
1723
1724 #[test]
1725 fn simple_func_latex() {
1726 sym! {
1727 let identity = simple_func1("identity", |t| t);
1728 let x = symbol("x");
1729 let f = identity(x);
1730 assert_eq!(f.to_latex(), "\\operatorname{identity}\\left(x\\right)");
1731 }
1732 }
1733
1734 #[test]
1735 fn simple_func_free_vars() {
1736 sym! {
1737 let identity = simple_func1("identity", |t| t);
1738 let x = symbol("x");
1739 let f = identity(x + symbol("y"));
1740 let vars = f.free_vars();
1741 assert!(vars.contains("x"));
1742 assert!(vars.contains("y"));
1743 assert!(!vars.contains("t"));
1744 }
1745 }
1746
1747 #[test]
1748 fn simple_func_subs() {
1749 sym! {
1750 let identity = simple_func1("identity", |t| t);
1751 let x = symbol("x");
1752 let f = identity(x);
1753 let g = f.subs("x", &constant(3.0));
1754 assert_eq!(format!("{}", g), "identity(3)");
1755 }
1756 }
1757
1758 #[test]
1759 fn simple_func_simplify_constants() {
1760 sym! {
1761 let square = simple_func1("square", |t| t * t);
1762 let f = square(constant(3.0));
1763 let s = f.simplify();
1764 assert_eq!(format!("{}", s), "9");
1765 }
1766 }
1767
1768 #[test]
1769 fn simple_func_nary() {
1770 sym! {
1771 let f = simple_func("triple_sum", 3, |v| v[0].clone() + v[1].clone() + v[2].clone());
1772 let x = symbol("x");
1773 let y = symbol("y");
1774 let z = symbol("z");
1775 let result = f(vec![x, y, z]);
1776 assert_eq!(format!("{}", result), "triple_sum(x, y, z)");
1777 assert_eq!(format!("{}", result.diff("x")), "1");
1778 }
1779 }
1780
1781 #[test]
1782 fn simple_func_expand() {
1783 sym! {
1784 let square = simple_func1("square", |t| t * t);
1785 let x = symbol("x");
1786 let f = square(x + 1.0);
1787 let expanded = f.expand();
1788 assert_eq!(format!("{}", expanded), "x^2 + 2 * x + 1");
1789 }
1790 }
1791
1792 #[test]
1795 fn simple_func_derivs_codegen() {
1796 sym! {
1798 let f = simple_func1_derivs("inv", |t| 1.0 / t, |t| [-1.0 / (t * t)]);
1799 let x = symbol("x");
1800 assert_eq!(f(x).to_rust("f64"), "1.0_f64 / x");
1801 }
1802 }
1803
/// d/da atan2(a, b) = b / (a^2 + b^2), which is 0.5 at (1, 1).
#[test]
fn safe_atan2_diff() {
    sym! {
        let a = symbol("a");
        let b = symbol("b");
        let d_da = safe_atan2(a, b).diff("a");
        let point = HashMap::from([("a", 1.0), ("b", 1.0)]);
        let slope = d_da.eval(&point).unwrap();
        assert!((slope - 0.5).abs() < 1e-10, "d/da at (1,1) = {}, expected 0.5", slope);
    }
}

/// safe_atan2 matches atan2 at a regular point: atan2(1, 1) = pi/4.
#[test]
fn safe_atan2_eval() {
    sym! {
        let a = symbol("a");
        let b = symbol("b");
        let point = HashMap::from([("a", 1.0), ("b", 1.0)]);
        let angle = safe_atan2(a, b).eval(&point).unwrap();
        // Report the actual value on failure, like the sibling tests do.
        assert!((angle - std::f64::consts::FRAC_PI_4).abs() < 1e-10,
            "safe_atan2(1, 1) = {}, expected pi/4", angle);
    }
}

/// atan2(sin t, cos t) = t around t = 0.5, so the chain rule must give 1.
#[test]
fn safe_atan2_chain_rule() {
    sym! {
        let t = symbol("t");
        let df = safe_atan2(sin(t), cos(t)).diff("t");
        let rate = df.eval(&HashMap::from([("t", 0.5)])).unwrap();
        assert!((rate - 1.0).abs() < 1e-8, "df/dt at t=0.5 = {}, expected 1", rate);
    }
}

/// At the origin the plain atan2 gradient is 0/0; the safe variant must
/// stay finite.
#[test]
fn safe_atan2_at_zero() {
    sym! {
        let a = symbol("a");
        let b = symbol("b");
        let d_da = safe_atan2(a, b).diff("a");
        let origin = HashMap::from([("a", 0.0), ("b", 0.0)]);
        let slope = d_da.eval(&origin).unwrap();
        assert!(slope.is_finite(), "derivative at (0,0) should be finite, got {}", slope);
    }
}
1854
/// safe_asin matches f64::asin inside [-1, 1] and saturates at pi/2 above it,
/// where plain asin would return NaN.
#[test]
fn safe_asin_eval() {
    sym! {
        let x = symbol("x");
        let f = safe_asin(x);

        let inside = f.eval(&HashMap::from([("x", 0.5)])).unwrap();
        let expected_inside = 0.5_f64.asin();
        assert!((inside - expected_inside).abs() < 1e-10,
            "safe_asin(0.5) = {}, expected {}", inside, expected_inside);

        let above = f.eval(&HashMap::from([("x", 1.5)])).unwrap();
        assert!((above - std::f64::consts::FRAC_PI_2).abs() < 1e-10,
            "safe_asin(1.5) = {}, expected pi/2", above);
    }
}

/// asin'(x) = 1/sqrt(1 - x^2) is infinite at x = 1; the safe variant's is not.
#[test]
fn safe_asin_deriv_finite() {
    sym! {
        let x = symbol("x");
        let da = safe_asin(x).diff("x");
        let slope = da.eval(&HashMap::from([("x", 1.0)])).unwrap();
        assert!(slope.is_finite(), "safe_asin derivative at 1.0 should be finite, got {}", slope);
    }
}

/// safe_acos matches f64::acos inside [-1, 1] and saturates at acos(-1) = pi
/// below it, where plain acos would return NaN.
#[test]
fn safe_acos_eval() {
    sym! {
        let x = symbol("x");
        let f = safe_acos(x);

        let inside = f.eval(&HashMap::from([("x", 0.5)])).unwrap();
        let expected_inside = 0.5_f64.acos();
        assert!((inside - expected_inside).abs() < 1e-10,
            "safe_acos(0.5) = {}, expected {}", inside, expected_inside);

        let below = f.eval(&HashMap::from([("x", -1.5)])).unwrap();
        assert!((below - std::f64::consts::PI).abs() < 1e-10,
            "safe_acos(-1.5) = {}, expected pi", below);
    }
}
1893
/// Generated Rust must keep the identity wrapper's argument parenthesised
/// when it is embedded in a larger expression.
#[test]
fn identity_codegen_parens() {
    sym! {
        let x = symbol("x");
        let wrapped = identity(c(1.0) - x * x);
        let code = (wrapped + epsilon * epsilon).to_rust("f64");
        assert!(code.contains("(-x.powf(2.0_f64) + 1.0_f64)"),
            "expected parens around identity body, got: {}", code);
    }
}

/// Differentiation passes straight through identity: d/dx (x * x) = 2x.
#[test]
fn identity_diff() {
    sym! {
        let x = symbol("x");
        let derivative = identity(x * x).diff("x");
        assert_eq!(derivative.to_string(), "2 * x");
    }
}
1914
/// acos'(x) = -1/sqrt(1 - x^2) is infinite at x = 1; the safe variant's is not.
#[test]
fn safe_acos_deriv_finite() {
    sym! {
        let x = symbol("x");
        let d_acos = safe_acos(x).diff("x");
        let slope = d_acos.eval(&HashMap::from([("x", 1.0)])).unwrap();
        assert!(slope.is_finite(), "safe_acos derivative at 1.0 should be finite, got {}", slope);
    }
}

/// Outside the real domain of asin/acos (|x| > 1) and of sqrt (x < 0), the
/// derivatives of the safe variants must still evaluate to finite numbers.
#[test]
fn safe_derivs_finite_outside_domain() {
    sym! {
        let x = symbol("x");
        let d_asin = safe_asin(x).diff("x");
        let d_acos = safe_acos(x).diff("x");
        let d_sqrt = safe_sqrt(x).diff("x");
        for v in [-5.0_f64, -1.5, 1.5, 5.0] {
            let vars = HashMap::from([("x", v)]);
            // Descriptive names; the old `c` shadowed the `c(..)` constant
            // helper used throughout these sym! blocks.
            let asin_slope = d_asin.eval(&vars).unwrap();
            let acos_slope = d_acos.eval(&vars).unwrap();
            assert!(asin_slope.is_finite(), "safe_asin'({}) should be finite, got {}", v, asin_slope);
            assert!(acos_slope.is_finite(), "safe_acos'({}) should be finite, got {}", v, acos_slope);
        }
        for v in [-5.0_f64, -1.0, -1e-12, 0.0] {
            let vars = HashMap::from([("x", v)]);
            let sqrt_slope = d_sqrt.eval(&vars).unwrap();
            assert!(sqrt_slope.is_finite(), "safe_sqrt'({}) should be finite, got {}", v, sqrt_slope);
        }
    }
}

/// safe_sqrt matches sqrt for positive input and yields ~0 (not NaN) at and
/// slightly below zero.
#[test]
fn safe_sqrt_eval() {
    sym! {
        let x = symbol("x");
        let f = safe_sqrt(x);

        let root_four = f.eval(&HashMap::from([("x", 4.0)])).unwrap();
        assert!((root_four - 2.0).abs() < 1e-10, "safe_sqrt(4) = {}, expected 2", root_four);

        // NaN would fail `abs() < 1e-10`, so this also rules out NaN.
        let tiny_negative = f.eval(&HashMap::from([("x", -1e-10)])).unwrap();
        assert!(tiny_negative.abs() < 1e-10, "safe_sqrt(-1e-10) = {}, expected ~0", tiny_negative);

        let at_zero = f.eval(&HashMap::from([("x", 0.0)])).unwrap();
        assert!(at_zero.abs() < 1e-10, "safe_sqrt(0) = {}, expected ~0", at_zero);
    }
}

/// sqrt'(x) = 1 / (2 sqrt(x)) is infinite at 0; the safe variant's is not.
#[test]
fn safe_sqrt_deriv_at_zero() {
    sym! {
        let x = symbol("x");
        let d_sqrt = safe_sqrt(x).diff("x");
        let slope = d_sqrt.eval(&HashMap::from([("x", 0.0)])).unwrap();
        assert!(slope.is_finite(), "safe_sqrt derivative at 0 should be finite, got {}", slope);
    }
}
1981
/// grad2 returns both partial derivatives at the given expressions: the
/// gradient of a * b at (x, y) is [y, x].
#[test]
fn grad2_basic() {
    sym! {
        let grad_product = grad2(|a, b| a * b);
        let [d_first, d_second] = grad_product(symbol("x"), symbol("y"));
        assert_eq!(d_first.to_string(), "y");
        assert_eq!(d_second.to_string(), "x");
    }
}

/// grad1 returns the single derivative: d/dt (t * t) at x is 2x.
#[test]
fn grad1_basic() {
    sym! {
        let grad_square = grad1(|t| t * t);
        let [d_only] = grad_square(symbol("x"));
        assert_eq!(d_only.to_string(), "2 * x");
    }
}
2005
/// Extern functions display as `name(args)`.
#[test]
fn extern_func_display() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let shown = rad_diff(x, y).to_string();
        assert_eq!(shown, "rad_diff(x, y)");
    }
}

/// rad_diff's symbolic gradient: d/dx = 1, d/dy = -1.
#[test]
fn extern_func_diff() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let f = rad_diff(x, y);
        assert_eq!(f.diff("x").to_string(), "1");
        assert_eq!(f.diff("y").to_string(), "-1");
    }
}

/// The gradient composes with the chain rule: d/dx rad_diff(x^2, y) = 2x.
#[test]
fn extern_func_chain_rule() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let derivative = rad_diff(x * x, y).diff("x");
        assert_eq!(derivative.to_string(), "2 * x");
    }
}

/// Numeric evaluation uses the registered evaluator: rad_diff(0.3, 0.1) = 0.2.
#[test]
fn extern_func_eval() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let point = HashMap::from([("x", 0.3), ("y", 0.1)]);
        let value = rad_diff(x, y).eval(&point).unwrap();
        // Report the actual value on failure, like the sibling tests do.
        assert!((value - 0.2).abs() < 1e-10, "rad_diff(0.3, 0.1) = {}, expected 0.2", value);
    }
}
2051
/// Numeric evaluation wraps the angle difference: rad_diff(0, 2*pi) is 0
/// rather than the plain difference -2*pi.
#[test]
fn extern_func_eval_wrapping() {
    sym! {
        let x = symbol("x");
        let full_turn = HashMap::from([("x", 2.0 * std::f64::consts::PI)]);
        let wrapped = rad_diff(constant(0.0), x).eval(&full_turn).unwrap();
        assert!(wrapped.abs() < 1e-10, "rad_diff(0, 2*pi) = {}, expected 0", wrapped);
    }
}

/// Code generation emits a call to the registered Rust path.
#[test]
fn extern_func_to_rust() {
    sym! {
        let code = rad_diff(symbol("x"), symbol("y")).to_rust("f64");
        assert_eq!(code, "arael::utils::rad_diff(x, y)");
    }
}

/// LaTeX rendering uses \operatorname with the underscore escaped.
#[test]
fn extern_func_latex() {
    sym! {
        let latex = rad_diff(symbol("x"), symbol("y")).to_latex();
        assert_eq!(latex, "\\operatorname{rad\\_diff}\\left(x, y\\right)");
    }
}

/// Substitution reaches into extern-function arguments.
#[test]
fn extern_func_subs() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let one = constant(1.0);
        let substituted = rad_diff(x, y).subs("x", &one);
        assert_eq!(substituted.to_string(), "rad_diff(1, y)");
    }
}
2095
/// Extern functions are opaque to the simplifier: constant arguments are not
/// folded into a number.
#[test]
fn extern_func_no_const_fold() {
    sym! {
        let simplified = rad_diff(constant(1.0), constant(2.0)).simplify();
        assert_eq!(simplified.to_string(), "rad_diff(1, 2)");
    }
}

/// `expand` leaves an extern-function call and its arguments untouched.
#[test]
fn extern_func_no_expand() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let expanded = rad_diff(x + 1.0, y).expand();
        assert_eq!(expanded.to_string(), "rad_diff(x + 1, y)");
    }
}

/// Free variables come from the call arguments only. `__a` / `__b` must not
/// leak; presumably these are rad_diff's internal parameter names (the
/// definition is not visible here).
#[test]
fn extern_func_free_vars() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let free = rad_diff(x, y).free_vars();
        assert!(free.contains("x"));
        assert!(free.contains("y"));
        for hidden in ["__a", "__b"] {
            assert!(!free.contains(hidden));
        }
    }
}
2131
/// rad_sum's gradient is (1, 1).
#[test]
fn rad_sum_diff() {
    sym! {
        let x = symbol("x");
        let y = symbol("y");
        let angle_sum = rad_sum(x, y);
        for var in ["x", "y"] {
            assert_eq!(angle_sum.diff(var).to_string(), "1");
        }
    }
}

/// rad_sum generates a call to its registered Rust path.
#[test]
fn rad_sum_to_rust() {
    sym! {
        let code = rad_sum(symbol("x"), symbol("y")).to_rust("f64");
        assert_eq!(code, "arael::utils::rad_sum(x, y)");
    }
}

/// Defining a new extern function involves four things: a display name, a
/// Rust path for codegen, a symbolic gradient (via grad2) and a numeric
/// evaluator (a plain fn over f64 arguments).
#[test]
fn extern_func_def() {
    sym! {
        fn my_eval(args: &[f64]) -> f64 { args[0] - args[1] }
        let my_diff = extern_func2("my_diff", "my_mod::diff", grad2(|a, b| a - b), my_eval);
        let f = my_diff(symbol("x"), symbol("y"));
        assert_eq!(f.to_string(), "my_diff(x, y)");
        assert_eq!(f.diff("x").to_string(), "1");
        assert_eq!(f.diff("y").to_string(), "-1");
        assert_eq!(f.to_rust("f64"), "my_mod::diff(x, y)");
    }
}
2168
/// H(x) is 0 for negative x and 1 otherwise, including exactly at 0.
#[test]
fn heaviside_eval() {
    sym! {
        let h = heaviside(symbol("x"));
        let below = HashMap::from([("x", -1.0)]);
        let at_zero = HashMap::from([("x", 0.0)]);
        let above = HashMap::from([("x", 3.0)]);
        assert_eq!(h.eval(&below).unwrap(), 0.0);
        assert_eq!(h.eval(&at_zero).unwrap(), 1.0);
        assert_eq!(h.eval(&above).unwrap(), 1.0);
    }
}

/// The derivative of H is taken to be 0: the delta at the step is dropped,
/// including through the chain rule.
#[test]
fn heaviside_diff() {
    sym! {
        let x = symbol("x");
        let d_plain = heaviside(x).diff("x");
        let d_composed = heaviside(x * x - 1.0).diff("x");
        assert_eq!(d_plain.to_string(), "0");
        assert_eq!(d_composed.to_string(), "0");
    }
}

/// Heaviside displays as `H(..)`.
#[test]
fn heaviside_display() {
    sym! {
        let shown = heaviside(symbol("x")).to_string();
        assert_eq!(shown, "H(x)");
    }
}

/// Product rule with H' = 0: d/dx [H(1 - x) * x^2] keeps only 2x * H(1 - x).
#[test]
fn heaviside_composition_diff() {
    sym! {
        let x = symbol("x");
        let gated = heaviside(1.0 - x) * x * x;
        let derivative = gated.diff("x");
        assert_eq!(derivative.to_string(), "2 * x * H(-x + 1)");
    }
}
2209
/// clamp passes values inside [lo, hi] through and saturates outside.
#[test]
fn clamp_eval() {
    sym! {
        let f = clamp(symbol("x"), c(0.0), c(1.0));
        for (input, expected) in [(0.5, 0.5), (-2.0, 0.0), (5.0, 1.0)] {
            assert_eq!(f.eval(&HashMap::from([("x", input)])).unwrap(), expected);
        }
    }
}

/// The symbolic derivative of clamp is that of its first argument. No factor
/// depending on the bounds is introduced.
#[test]
fn clamp_diff_passthrough() {
    sym! {
        let x = symbol("x");
        let d_linear = clamp(x, c(0.0), c(1.0)).diff("x");
        let d_square = clamp(x * x, c(0.0), c(1.0)).diff("x");
        assert_eq!(d_linear.to_string(), "1");
        assert_eq!(d_square.to_string(), "2 * x");
    }
}

/// clamp displays as a three-argument call.
#[test]
fn clamp_display() {
    sym! {
        let shown = clamp(symbol("x"), c(0.0), c(1.0)).to_string();
        assert_eq!(shown, "clamp(x, 0, 1)");
    }
}

/// With all-constant arguments, `simplify` evaluates the clamp.
#[test]
fn clamp_simplify_constants() {
    sym! {
        for (input, expected) in [(5.0, "1"), (-3.0, "0"), (0.5, "0.5")] {
            let folded = clamp(c(input), c(0.0), c(1.0)).simplify();
            assert_eq!(folded.to_string(), expected);
        }
    }
}
2253
/// Clamping the argument makes asin total: inputs outside [-1, 1] saturate
/// at +/- pi/2 instead of producing NaN.
#[test]
fn clamp_asin_eval() {
    sym! {
        let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
        let f = my_asin(symbol("x"));

        let val = f.eval(&HashMap::from([("x", 0.5)])).unwrap();
        let expected = 0.5_f64.asin();
        assert!((val - expected).abs() < 1e-10, "my_asin(0.5) = {}, expected {}", val, expected);

        let val_hi = f.eval(&HashMap::from([("x", 1.5)])).unwrap();
        assert!((val_hi - std::f64::consts::FRAC_PI_2).abs() < 1e-10,
            "my_asin(1.5) = {}, expected pi/2", val_hi);

        let val_lo = f.eval(&HashMap::from([("x", -1.5)])).unwrap();
        assert!((val_lo + std::f64::consts::FRAC_PI_2).abs() < 1e-10,
            "my_asin(-1.5) = {}, expected -pi/2", val_lo);
    }
}

/// Inside the clamp bounds the derivative is the plain asin'(x) =
/// 1/sqrt(1 - x^2), on both sides of zero.
#[test]
fn clamp_asin_diff() {
    sym! {
        let my_asin = simple_func1("my_asin", |t| asin(clamp(t, c(-1.0), c(1.0))));
        let df = my_asin(symbol("x")).diff("x");
        // (+/-0.5)^2 = 0.25 exactly, so both points share the same expected slope.
        let expected = 1.0 / (1.0 - 0.25_f64).sqrt();
        for point in [0.5_f64, -0.5] {
            let dval = df.eval(&HashMap::from([("x", point)])).unwrap();
            assert!((dval - expected).abs() < 1e-10,
                "my_asin'({}) = {}, expected {}", point, dval, expected);
        }
    }
}
2291
/// Heaviside generates a `.heaviside()` method call in Rust code.
#[test]
fn heaviside_to_rust() {
    sym! {
        let code = heaviside(symbol("x")).to_rust("f64");
        assert_eq!(code, "x.heaviside()");
    }
}

/// clamp generates a `.clamp(lo, hi)` method call with typed literals.
#[test]
fn clamp_to_rust() {
    sym! {
        let code = clamp(symbol("x"), c(0.0), c(1.0)).to_rust("f64");
        assert_eq!(code, "x.clamp(0.0_f64, 1.0_f64)");
    }
}

/// The parser recognises `H(..)`: it round-trips through Display and
/// differentiates to 0.
#[test]
fn parse_heaviside() {
    let parsed = parse("H(x)").unwrap();
    assert_eq!(parsed.to_string(), "H(x)");
    assert_eq!(parsed.diff("x").to_string(), "0");
}

/// The parser recognises three-argument `clamp(..)`: it round-trips through
/// Display and differentiates through its first argument.
#[test]
fn parse_clamp() {
    let parsed = parse("clamp(x, 0, 1)").unwrap();
    assert_eq!(parsed.to_string(), "clamp(x, 0, 1)");
    assert_eq!(parsed.diff("x").to_string(), "1");
}
2321
/// pi displays by name.
#[test]
fn named_const_pi_display() {
    assert_eq!(pi().to_string(), "pi");
}

/// pi evaluates to std::f64::consts::PI with no variable bindings.
#[test]
fn named_const_pi_eval() {
    let no_bindings = HashMap::new();
    let value = pi().eval(&no_bindings).unwrap();
    assert_eq!(value, std::f64::consts::PI);
}

/// Named constants differentiate to 0.
#[test]
fn named_const_pi_diff() {
    let derivative = pi().diff("x");
    assert_eq!(derivative.to_string(), "0");
}

/// pi generates the std constant path for the requested float type.
#[test]
fn named_const_pi_codegen() {
    for (float_ty, path) in [("f64", "std::f64::consts::PI"), ("f32", "std::f32::consts::PI")] {
        assert_eq!(pi().to_rust(float_ty), path);
    }
}

/// pi renders as the LaTeX `\pi` command.
#[test]
fn named_const_pi_latex() {
    let latex = pi().to_latex();
    assert_eq!(latex, "\\pi");
}

/// epsilon displays by name.
#[test]
fn named_const_epsilon_display() {
    assert_eq!(epsilon().to_string(), "epsilon");
}

/// epsilon evaluates to f64::EPSILON.
#[test]
fn named_const_epsilon_eval() {
    let no_bindings = HashMap::new();
    let value = epsilon().eval(&no_bindings).unwrap();
    assert_eq!(value, f64::EPSILON);
}

/// epsilon generates the EPSILON associated constant of the requested type.
#[test]
fn named_const_epsilon_codegen() {
    for (float_ty, path) in [("f64", "f64::EPSILON"), ("f32", "f32::EPSILON")] {
        assert_eq!(epsilon().to_rust(float_ty), path);
    }
}

/// Euler's number displays as `e`.
#[test]
fn named_const_euler_display() {
    assert_eq!(euler().to_string(), "e");
}

/// Euler's number evaluates to std::f64::consts::E.
#[test]
fn named_const_euler_eval() {
    let no_bindings = HashMap::new();
    let value = euler().eval(&no_bindings).unwrap();
    assert_eq!(value, std::f64::consts::E);
}

/// Euler's number generates the std f64 constant path.
#[test]
fn named_const_euler_codegen() {
    let code = euler().to_rust("f64");
    assert_eq!(code, "std::f64::consts::E");
}
2383
/// epsilon stays symbolic under simplification instead of being folded into
/// its numeric value.
#[test]
fn named_const_epsilon_survives_simplification() {
    sym! {
        let x = symbol("x");
        let simplified = (x + epsilon()).simplify();
        assert_eq!(simplified.to_string(), "x + epsilon");
    }
}

/// Named constants are not free variables.
#[test]
fn named_const_not_free_var() {
    sym! {
        let x = symbol("x");
        let free = (x + pi()).free_vars();
        assert!(free.contains("x"));
        assert!(!free.contains("pi"));
    }
}

/// A user-defined named constant supplies five things: its display name,
/// numeric value, f32 and f64 Rust paths (in that argument order, as below),
/// and LaTeX form. Each one is checked.
#[test]
fn named_const_custom() {
    let tau = named_const("tau", std::f64::consts::TAU,
        "std::f32::consts::TAU", "std::f64::consts::TAU", "\\tau");
    assert_eq!(tau.to_string(), "tau");
    let vars = HashMap::new();
    assert_eq!(tau.eval(&vars).unwrap(), std::f64::consts::TAU);
    assert_eq!(tau.to_rust("f64"), "std::f64::consts::TAU");
    // The f32 path is supplied above, so exercise it too (mirrors the pi codegen test).
    assert_eq!(tau.to_rust("f32"), "std::f32::consts::TAU");
    assert_eq!(tau.to_latex(), "\\tau");
}

/// Like terms in a named constant collect: pi + pi = 2 * pi.
#[test]
fn named_const_pi_add_pi() {
    sym! {
        let simplified = (pi() + pi()).simplify();
        assert_eq!(simplified.to_string(), "2 * pi");
    }
}

/// pi - pi cancels to 0.
#[test]
fn named_const_pi_sub_pi() {
    sym! {
        let simplified = (pi() - pi()).simplify();
        assert_eq!(simplified.to_string(), "0");
    }
}

/// pi * pi collects into a power.
#[test]
fn named_const_pi_mul_pi() {
    sym! {
        let simplified = (pi() * pi()).simplify();
        assert_eq!(simplified.to_string(), "pi^2");
    }
}

/// Repeated epsilon terms collect alongside a symbol.
#[test]
fn named_const_epsilon_add() {
    sym! {
        let x = symbol("x");
        let simplified = (x + epsilon() + epsilon()).simplify();
        assert_eq!(simplified.to_string(), "x + 2 * epsilon");
    }
}
2449
/// sin(pi) simplifies exactly to 0.
#[test]
fn trig_sin_pi() {
    sym! {
        let simplified = sin(pi()).simplify();
        assert_eq!(simplified.to_string(), "0");
    }
}

/// cos(pi) simplifies exactly to -1.
#[test]
fn trig_cos_pi() {
    sym! {
        let simplified = cos(pi()).simplify();
        assert_eq!(simplified.to_string(), "-1");
    }
}

/// sin(pi/2) simplifies exactly to 1.
#[test]
fn trig_sin_pi_half() {
    sym! {
        let simplified = sin(pi() / 2.0).simplify();
        assert_eq!(simplified.to_string(), "1");
    }
}

/// cos(pi/2) simplifies exactly to 0.
#[test]
fn trig_cos_pi_half() {
    sym! {
        let simplified = cos(pi() / 2.0).simplify();
        assert_eq!(simplified.to_string(), "0");
    }
}

/// sin(pi/4) simplifies to something numerically equal to 1/sqrt(2). The
/// exact display form is not pinned here.
#[test]
fn trig_sin_pi_quarter() {
    sym! {
        let simplified = sin(pi() / 4.0).simplify();
        let value = simplified.eval(&HashMap::new()).unwrap();
        assert!((value - std::f64::consts::FRAC_1_SQRT_2).abs() < 1e-10);
    }
}

/// cos(pi/3) simplifies exactly to 0.5.
#[test]
fn trig_cos_pi_third() {
    sym! {
        let simplified = cos(pi() / 3.0).simplify();
        assert_eq!(simplified.to_string(), "0.5");
    }
}

/// sin(2 pi) simplifies exactly to 0.
#[test]
fn trig_sin_2pi() {
    sym! {
        let simplified = sin(2.0 * pi()).simplify();
        assert_eq!(simplified.to_string(), "0");
    }
}

/// cos(2 pi) simplifies exactly to 1.
#[test]
fn trig_cos_2pi() {
    sym! {
        let simplified = cos(2.0 * pi()).simplify();
        assert_eq!(simplified.to_string(), "1");
    }
}

/// tan(pi) simplifies exactly to 0.
#[test]
fn trig_tan_pi() {
    sym! {
        let simplified = tan(pi()).simplify();
        assert_eq!(simplified.to_string(), "0");
    }
}

/// sin(pi/6) simplifies exactly to 0.5.
#[test]
fn trig_sin_pi_sixth() {
    sym! {
        let simplified = sin(pi() / 6.0).simplify();
        assert_eq!(simplified.to_string(), "0.5");
    }
}

/// ln(e) simplifies to 1.
#[test]
fn ln_e() {
    sym! {
        let simplified = ln(euler()).simplify();
        assert_eq!(simplified.to_string(), "1");
    }
}
2516
/// Inside `sym!`, a bare `pi` identifier is accepted as the named constant.
#[test]
fn sym_macro_bare_pi() {
    sym! {
        let x = symbol("x");
        let f = 2.0 * pi * x;
        let shown = f.to_string();
        assert_eq!(shown, "2 * x * pi");
    }
}

/// Inside `sym!`, a bare `epsilon` identifier is accepted as the named constant.
#[test]
fn sym_macro_bare_epsilon() {
    sym! {
        let x = symbol("x");
        let f = x * x + epsilon;
        let shown = f.to_string();
        assert_eq!(shown, "x^2 + epsilon");
    }
}

/// The explicit call form `pi()` keeps working inside `sym!`.
#[test]
fn sym_macro_pi_call_still_works() {
    sym! {
        let f = pi();
        let shown = f.to_string();
        assert_eq!(shown, "pi");
    }
}

/// ln(e^x) simplifies back to x.
#[test]
fn ln_e_pow_x() {
    sym! {
        let x = symbol("x");
        let simplified = ln(pow(euler(), x)).simplify();
        assert_eq!(simplified.to_string(), "x");
    }
}
2553 }
2554}
2555