1#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced the value of an expression node.
///
/// `None` marks a leaf; every other variant names the function that was
/// applied to the node's operand(s).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// No operation: the node is a leaf.
    None,
    /// Sum of two operands.
    Add,
    /// First operand minus the second.
    Sub,
    /// Product of two operands.
    Mul,
    /// First operand divided by the second.
    Div,
    /// Hyperbolic tangent of one operand.
    Tanh,
    /// Natural exponential (`e^x`) of one operand.
    Exp,
    /// First operand raised to the power of the second.
    Pow,
    /// Rectified linear unit (`max(x, 0)`) of one operand.
    ReLU,
    /// Natural logarithm of one operand.
    Log,
    /// Arithmetic negation of one operand.
    Neg,
}
23
24impl Operation {
25 fn assert_is_type(&self, expr_type: ExprType) {
26 match self {
27 Operation::None => assert_eq!(expr_type, ExprType::Leaf),
28 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
29 _ => assert_eq!(expr_type, ExprType::Binary),
30 }
31 }
32}
33
/// Arity class of an expression node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// A node without operands.
    Leaf,
    /// A node with exactly one operand.
    Unary,
    /// A node with exactly two operands.
    Binary,
}
40
41#[derive(Debug, Clone)]
51pub struct Expr {
52 operand1: Option<Box<Expr>>,
53 operand2: Option<Box<Expr>>,
54 operation: Operation,
55 pub result: f64,
57 pub is_learnable: bool,
59 grad: f64,
60 pub name: String,
62}
63
64impl Expr {
65 pub fn new_leaf(value: f64, name: &str) -> Expr {
74 Expr {
75 operand1: None,
76 operand2: None,
77 operation: Operation::None,
78 result: value,
79 is_learnable: true,
80 grad: 0.0,
81 name: name.to_string(),
82 }
83 }
84
85 fn expr_type(&self) -> ExprType {
86 match self.operation {
87 Operation::None => ExprType::Leaf,
88 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
89 _ => ExprType::Binary,
90 }
91 }
92
93 fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
94 operation.assert_is_type(ExprType::Unary);
95 Expr {
96 operand1: Some(Box::new(operand)),
97 operand2: None,
98 operation,
99 result,
100 is_learnable: false,
101 grad: 0.0,
102 name: name.to_string(),
103 }
104 }
105
106 fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
107 operation.assert_is_type(ExprType::Binary);
108 Expr {
109 operand1: Some(Box::new(operand1)),
110 operand2: Some(Box::new(operand2)),
111 operation,
112 result,
113 is_learnable: false,
114 grad: 0.0,
115 name: name.to_string(),
116 }
117 }
118
119 pub fn tanh(self, name: &str) -> Expr {
131 let result = self.result.tanh();
132 Expr::new_unary(self, Operation::Tanh, result, name)
133 }
134
135 pub fn relu(self, name: &str) -> Expr {
147 let result = self.result.max(0.0);
148 Expr::new_unary(self, Operation::ReLU, result, name)
149 }
150
151 pub fn exp(self, name: &str) -> Expr {
163 let result = self.result.exp();
164 Expr::new_unary(self, Operation::Exp, result, name)
165 }
166
167 pub fn pow(self, exponent: Expr, name: &str) -> Expr {
180 let result = self.result.powf(exponent.result);
181 Expr::new_binary(self, exponent, Operation::Pow, result, name)
182 }
183
184 pub fn log(self, name: &str) -> Expr {
196 let result = self.result.ln();
197 Expr::new_unary(self, Operation::Log, result, name)
198 }
199
200 pub fn neg(self, name: &str) -> Expr {
212 let result = -self.result;
213 Expr::new_unary(self, Operation::Neg, result, name)
214 }
215
216 pub fn recalculate(&mut self) {
231 match self.expr_type() {
232 ExprType::Leaf => {}
233 ExprType::Unary => {
234 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
235 operand1.recalculate();
236
237 self.result = match self.operation {
238 Operation::Tanh => operand1.result.tanh(),
239 Operation::Exp => operand1.result.exp(),
240 Operation::ReLU => operand1.result.max(0.0),
241 Operation::Log => operand1.result.ln(),
242 Operation::Neg => -operand1.result,
243 _ => panic!("Invalid unary operation {:?}", self.operation),
244 };
245 }
246 ExprType::Binary => {
247 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
248 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
249
250 operand1.recalculate();
251 operand2.recalculate();
252
253 self.result = match self.operation {
254 Operation::Add => operand1.result + operand2.result,
255 Operation::Sub => operand1.result - operand2.result,
256 Operation::Mul => operand1.result * operand2.result,
257 Operation::Div => operand1.result / operand2.result,
258 Operation::Pow => operand1.result.powf(operand2.result),
259 _ => panic!("Invalid binary operation: {:?}", self.operation),
260 };
261 }
262 }
263 }
264
265 pub fn learn(&mut self, learning_rate: f64) {
287 self.grad = 1.0;
288 self.learn_internal(learning_rate);
289 }
290
291 fn learn_internal(&mut self, learning_rate: f64) {
292 match self.expr_type() {
293 ExprType::Leaf => {
294 if self.is_learnable {
297 self.result -= learning_rate * self.grad;
298 }
299 }
300 ExprType::Unary => {
301 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
302
303 match self.operation {
304 Operation::Tanh => {
305 let tanh_grad = 1.0 - (self.result * self.result);
306 operand1.grad = self.grad * tanh_grad;
307 }
308 Operation::Exp => {
309 operand1.grad = self.grad * self.result;
310 }
311 Operation::ReLU => {
312 operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
313 }
314 Operation::Log => {
315 operand1.grad = self.grad / operand1.result;
316 }
317 Operation::Neg => {
318 operand1.grad = -self.grad;
319 }
320 _ => panic!("Invalid unary operation {:?}", self.operation),
321 }
322
323 operand1.learn_internal(learning_rate);
324 }
325 ExprType::Binary => {
326 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
327 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
328
329 match self.operation {
330 Operation::Add => {
331 operand1.grad = self.grad;
332 operand2.grad = self.grad;
333 }
334 Operation::Sub => {
335 operand1.grad = self.grad;
336 operand2.grad = -self.grad;
337 }
338 Operation::Mul => {
339 let operand2_result = operand2.result;
340 let operand1_result = operand1.result;
341
342 operand1.grad = self.grad * operand2_result;
343 operand2.grad = self.grad * operand1_result;
344 }
345 Operation::Div => {
346 let operand2_result = operand2.result;
347 let operand1_result = operand1.result;
348
349 operand1.grad = self.grad / operand2_result;
350 operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
351 }
352 Operation::Pow => {
353 let exponent = operand2.result;
354 let base = operand1.result;
355
356 operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
357 operand2.grad = self.grad * base.powf(exponent) * base.ln();
358 }
359 _ => panic!("Invalid binary operation: {:?}", self.operation),
360 }
361
362 operand1.learn_internal(learning_rate);
363 operand2.learn_internal(learning_rate);
364 }
365 }
366 }
367
368 pub fn find(&self, name: &str) -> Option<&Expr> {
384 if self.name == name {
385 return Some(self);
386 }
387
388 match self.expr_type() {
389 ExprType::Leaf => None,
390 ExprType::Unary => {
391 let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
392 operand1.find(name)
393 }
394 ExprType::Binary => {
395 let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
396 let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
397
398 let result = operand1.find(name);
399 if result.is_some() {
400 return result;
401 }
402
403 operand2.find(name)
404 }
405 }
406 }
407}
408
409impl Add for Expr {
425 type Output = Expr;
426
427 fn add(self, other: Expr) -> Expr {
428 let result = self.result + other.result;
429 let name = &format!("({} + {})", self.name, other.name);
430 Expr::new_binary(self, other, Operation::Add, result, name)
431 }
432}
433
434impl Add<f64> for Expr {
448 type Output = Expr;
449
450 fn add(self, other: f64) -> Expr {
451 let operand2 = Expr::new_leaf(other, &other.to_string());
452 self + operand2
453 }
454}
455
456impl Add<Expr> for f64 {
470 type Output = Expr;
471
472 fn add(self, other: Expr) -> Expr {
473 let operand1 = Expr::new_leaf(self, &self.to_string());
474 operand1 + other
475 }
476}
477
478impl Mul for Expr {
495 type Output = Expr;
496
497 fn mul(self, other: Expr) -> Expr {
498 let result = self.result * other.result;
499 let name = &format!("({} * {})", self.name, other.name);
500 Expr::new_binary(self, other, Operation::Mul, result, name)
501 }
502}
503
504impl Mul<f64> for Expr {
519 type Output = Expr;
520
521 fn mul(self, other: f64) -> Expr {
522 let operand2 = Expr::new_leaf(other, &other.to_string());
523 self * operand2
524 }
525}
526
527impl Mul<Expr> for f64 {
542 type Output = Expr;
543
544 fn mul(self, other: Expr) -> Expr {
545 let operand1 = Expr::new_leaf(self, &self.to_string());
546 operand1 * other
547 }
548}
549
550impl Sub for Expr {
567 type Output = Expr;
568
569 fn sub(self, other: Expr) -> Expr {
570 let result = self.result - other.result;
571 let name = &format!("({} - {})", self.name, other.name);
572 Expr::new_binary(self, other, Operation::Sub, result, name)
573 }
574}
575
576impl Sub<f64> for Expr {
591 type Output = Expr;
592
593 fn sub(self, other: f64) -> Expr {
594 let operand2 = Expr::new_leaf(other, &other.to_string());
595 self - operand2
596 }
597}
598
599impl Sub<Expr> for f64 {
614 type Output = Expr;
615
616 fn sub(self, other: Expr) -> Expr {
617 let operand1 = Expr::new_leaf(self, &self.to_string());
618 operand1 - other
619 }
620}
621
622impl Div for Expr {
639 type Output = Expr;
640
641 fn div(self, other: Expr) -> Expr {
642 let result = self.result / other.result;
643 let name = &format!("({} / {})", self.name, other.name);
644 Expr::new_binary(self, other, Operation::Div, result, name)
645 }
646}
647
648impl Div<f64> for Expr {
663 type Output = Expr;
664
665 fn div(self, other: f64) -> Expr {
666 let operand2 = Expr::new_leaf(other, &other.to_string());
667 self / operand2
668 }
669}
670
671impl Sum for Expr {
692 fn sum<I>(iter: I) -> Self
693 where
694 I: Iterator<Item = Self>,
695 {
696 iter.reduce(|acc, x| acc + x)
697 .unwrap_or(Expr::new_leaf(0.0, "0.0"))
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `f1` and `f2` differ by less than `1e-7`.
    ///
    /// Used for every value that passes through `tanh`, `exp`, `ln` or
    /// `powf`: the standard library does not guarantee bit-identical results
    /// for those functions across platforms, so exact equality is brittle.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    #[test]
    fn test_new_leaf() {
        let expr = Expr::new_leaf(1.0, "x");
        assert_eq!(expr.result, 1.0);
    }

    #[test]
    fn test_unary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");

        assert_eq!(expr2.result, 1.1);
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");

        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(expr4.result, 1.2);
        let expr3 = expr4.operand1.unwrap();
        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        assert_float_eq(expr2.result, 0.7615941559557649);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Tanh);
        assert!(expr2.operand2.is_none());

        fn get_tanh(x: f64) -> f64 {
            Expr::new_leaf(x, "x").tanh("tanh(x)").result
        }

        assert_float_eq(get_tanh(10.74), 0.9999999);
        assert_float_eq(get_tanh(-10.74), -0.9999999);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.exp("exp(x)");

        assert_float_eq(expr2.result, 2.718281828459045);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Exp);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_relu() {
        // Negative input is clamped to zero.
        let expr = Expr::new_leaf(-1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 0.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, -1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());

        // Positive input passes through unchanged.
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let result = expr.pow(expr2, "x^y");

        assert_float_eq(result.result, 8.0);
        assert!(result.operand1.is_some());
        assert_eq!(result.operand1.unwrap().result, 2.0);
        assert_eq!(result.operation, Operation::Pow);

        assert!(result.operand2.is_some());
        assert_eq!(result.operand2.unwrap().result, 3.0);
    }

    #[test]
    fn test_add() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = expr + expr2;

        assert_eq!(expr3.result, 3.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
        assert_eq!(expr3.operation, Operation::Add);
        assert_eq!(expr3.name, "(x + y)");
    }

    #[test]
    fn test_add_f64() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr + 2.0;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(x + 2)");
    }

    #[test]
    fn test_add_f64_expr() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = 2.0 + expr;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(2 + x)");
    }

    #[test]
    fn test_mul() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr * expr2;

        assert_eq!(expr3.result, 6.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Mul);
        assert_eq!(expr3.name, "(x * y)");
    }

    #[test]
    fn test_mul_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr * 3.0;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(x * 3)");
    }

    #[test]
    fn test_mul_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 * expr;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(3 * x)");
    }

    #[test]
    fn test_sub() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr - expr2;

        assert_eq!(expr3.result, -1.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Sub);
        assert_eq!(expr3.name, "(x - y)");
    }

    #[test]
    fn test_sub_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr - 3.0;

        assert_eq!(expr2.result, -1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(x - 3)");
    }

    #[test]
    fn test_sub_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 - expr;

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(3 - x)");
    }

    #[test]
    fn test_div() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr / expr2;

        assert_eq!(expr3.result, 2.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 6.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Div);
        assert_eq!(expr3.name, "(x / y)");
    }

    #[test]
    fn test_div_f64() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = expr / 3.0;

        assert_eq!(expr2.result, 2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 6.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Div);
        assert_eq!(expr2.name, "(x / 3)");
    }

    #[test]
    fn test_log() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr.log("log(x)");

        assert_float_eq(expr2.result, 0.6931471805599453);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Log);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_neg() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr.neg("neg(x)");

        assert_eq!(expr2.result, -2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Neg);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_recalculate() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let mut sum = x + y;
        assert_eq!(sum.result, 3.0);

        // Changing a leaf does not affect the parent until `recalculate` runs.
        sum.operand1.as_mut().unwrap().result = 5.0;
        assert_eq!(sum.result, 3.0);

        sum.recalculate();
        assert_eq!(sum.result, 7.0);
    }

    #[test]
    fn test_backpropagation_add() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 + operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_sub() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 - operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, -1.0);
    }

    #[test]
    fn test_backpropagation_mul() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 * operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 4.0);
        assert_eq!(operand2.grad, 3.0);
    }

    #[test]
    fn test_backpropagation_div() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 / operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 0.25);
        assert_eq!(operand2.grad, -0.1875);
    }

    #[test]
    fn test_backpropagation_tanh() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.tanh("tanh(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_float_eq(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let operand1 = Expr::new_leaf(-1.0, "x");
        let mut expr2 = operand1.relu("relu(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.exp("exp(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_float_eq(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_log() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let mut expr2 = operand1.log("log(x)");

        expr2.learn(1e-09);

        // d ln(x)/dx = 1/x
        let operand1 = expr2.operand1.unwrap();
        assert_float_eq(operand1.grad, 0.5);
    }

    #[test]
    fn test_backpropagation_neg() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let mut expr2 = operand1.neg("neg(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, -1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let operand2 = Expr::new_leaf(3.0, "y");
        let mut expr3 = operand1.pow(operand2, "x^y");

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_float_eq(operand1.grad, 12.0);
        assert_float_eq(operand2.grad, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let expr3 = operand1 + operand2;
        let mut expr4 = expr3.tanh("tanh(x + y)");

        expr4.learn(1e-09);

        let expr3 = expr4.operand1.unwrap();
        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();

        assert_float_eq(expr3.grad, 0.009866037165440211);
        assert_float_eq(operand1.grad, 0.009866037165440211);
        assert_float_eq(operand2.grad, 0.009866037165440211);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.8813735870195432, "b");

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh("tanh(n)");

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = n.operand1.unwrap();
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = n.operand2.unwrap();
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = x1w1_x2w2.operand1.unwrap();
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = x1w1_x2w2.operand2.unwrap();
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        let x1 = x1w1.operand1.unwrap();
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        let w1 = x1w1.operand2.unwrap();
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        let x2 = x2w2.operand1.unwrap();
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        let w2 = x2w2.operand2.unwrap();
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("y");
        assert!(found.is_none());
    }

    #[test]
    fn test_sum_iterator() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_leaf(3.0, "z");

        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    #[test]
    fn test_find_after_clone() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");
        let expr2_clone = expr2.clone();

        let found = expr2_clone.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }
}