1#![deny(missing_docs)]
7use std::collections::VecDeque;
8use std::ops::{Add, Div, Mul, Sub};
9use std::iter::Sum;
10
/// The operation that produced an expression node.
///
/// The enum carries no data, so equality is total; `Eq` and `Hash` are derived so
/// operations can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum Operation {
    /// A leaf: the node holds a plain value and has no operands.
    None,
    /// Binary addition, `a + b`.
    Add,
    /// Binary subtraction, `a - b`.
    Sub,
    /// Binary multiplication, `a * b`.
    Mul,
    /// Binary division, `a / b`.
    Div,
    /// Unary hyperbolic tangent.
    Tanh,
    /// Unary natural exponential, `e^x`.
    Exp,
    /// Binary power: the first operand raised to the second.
    Pow,
    /// Unary rectified linear unit, `max(x, 0)`.
    ReLU,
    /// Unary natural logarithm.
    Log,
    /// Unary negation, `-x`.
    Neg,
}
25
26impl Operation {
27 fn assert_is_type(&self, expr_type: ExprType) {
28 match self {
29 Operation::None => assert_eq!(expr_type, ExprType::Leaf),
30 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
31 _ => assert_eq!(expr_type, ExprType::Binary),
32 }
33 }
34}
35
/// Arity of an expression node, i.e. how many operands it has.
///
/// A fieldless value type, so it is `Copy` and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExprType {
    /// No operands: the node holds a plain value.
    Leaf,
    /// Exactly one operand.
    Unary,
    /// Exactly two operands.
    Binary,
}
42
/// A node in a scalar expression tree that supports forward evaluation and
/// gradient-based learning.
///
/// Leaves hold plain values. Every other node records the operation that produced
/// it and owns its operand(s) through `Box`, so an `Expr` is a tree: each node has
/// exactly one parent, and `clone` copies the whole subtree.
#[derive(Debug, Clone)]
pub struct Expr {
    /// First operand; `None` for leaves.
    pub(crate) operand1: Option<Box<Expr>>,
    /// Second operand; only `Some` for binary operations.
    pub(crate) operand2: Option<Box<Expr>>,
    /// The operation that produced this node (`Operation::None` for leaves).
    pub(crate) operation: Operation,
    /// Current value of this node.
    ///
    /// For non-leaf nodes it is computed when the node is built and refreshed by
    /// `Expr::recalculate`; `Expr::learn` only changes the values of leaves.
    pub result: f64,
    /// Whether `Expr::learn` may update this node's value.
    ///
    /// The update is only applied when `learn` reaches a leaf. Leaves created with
    /// `Expr::new_leaf` start out learnable; nodes built from operations do not.
    pub is_learnable: bool,
    /// Gradient written by the most recent `Expr::learn` call that reached this
    /// node; `0.0` for freshly built nodes.
    pub(crate) grad: f64,
    /// Optional label, used by `Expr::find` and `Expr::find_mut` to locate the node.
    pub name: Option<String>,
}
66
67impl Expr {
68 pub fn new_leaf(value: f64) -> Expr {
77 Expr {
78 operand1: None,
79 operand2: None,
80 operation: Operation::None,
81 result: value,
82 is_learnable: true,
83 grad: 0.0,
84 name: None,
85 }
86 }
87
88 pub fn new_leaf_with_name(value: f64, name: &str) -> Expr {
99 let mut expr = Expr::new_leaf(value);
100 expr.name = Some(name.to_string());
101 expr
102 }
103
104 pub(crate) fn expr_type(&self) -> ExprType {
105 match self.operation {
106 Operation::None => ExprType::Leaf,
107 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
108 _ => ExprType::Binary,
109 }
110 }
111
112 fn new_unary(operand: Expr, operation: Operation, result: f64) -> Expr {
113 operation.assert_is_type(ExprType::Unary);
114 Expr {
115 operand1: Some(Box::new(operand)),
116 operand2: None,
117 operation,
118 result,
119 is_learnable: false,
120 grad: 0.0,
121 name: None,
122 }
123 }
124
125 fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64) -> Expr {
126 operation.assert_is_type(ExprType::Binary);
127 Expr {
128 operand1: Some(Box::new(operand1)),
129 operand2: Some(Box::new(operand2)),
130 operation,
131 result,
132 is_learnable: false,
133 grad: 0.0,
134 name: None,
135 }
136 }
137
138 pub fn tanh(self) -> Expr {
150 let result = self.result.tanh();
151 Expr::new_unary(self, Operation::Tanh, result)
152 }
153
154 pub fn relu(self) -> Expr {
166 let result = self.result.max(0.0);
167 Expr::new_unary(self, Operation::ReLU, result)
168 }
169
170 pub fn exp(self) -> Expr {
182 let result = self.result.exp();
183 Expr::new_unary(self, Operation::Exp, result)
184 }
185
186 pub fn pow(self, exponent: Expr) -> Expr {
199 let result = self.result.powf(exponent.result);
200 Expr::new_binary(self, exponent, Operation::Pow, result)
201 }
202
203 pub fn log(self) -> Expr {
215 let result = self.result.ln();
216 Expr::new_unary(self, Operation::Log, result)
217 }
218
219 pub fn neg(self) -> Expr {
231 let result = -self.result;
232 Expr::new_unary(self, Operation::Neg, result)
233 }
234
235 pub fn recalculate(&mut self) {
266 match self.expr_type() {
271 ExprType::Leaf => {}
272 ExprType::Unary => {
273 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
274 operand1.recalculate();
275
276 self.result = match self.operation {
277 Operation::Tanh => operand1.result.tanh(),
278 Operation::Exp => operand1.result.exp(),
279 Operation::ReLU => operand1.result.max(0.0),
280 Operation::Log => operand1.result.ln(),
281 Operation::Neg => -operand1.result,
282 _ => panic!("Invalid unary operation {:?}", self.operation),
283 };
284 }
285 ExprType::Binary => {
286 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
287 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
288
289 operand1.recalculate();
290 operand2.recalculate();
291
292 self.result = match self.operation {
293 Operation::Add => operand1.result + operand2.result,
294 Operation::Sub => operand1.result - operand2.result,
295 Operation::Mul => operand1.result * operand2.result,
296 Operation::Div => operand1.result / operand2.result,
297 Operation::Pow => operand1.result.powf(operand2.result),
298 _ => panic!("Invalid binary operation: {:?}", self.operation),
299 };
300 }
301 }
302 }
303
304 pub fn learn(&mut self, learning_rate: f64) {
326 self.grad = 1.0;
327
328 let mut queue = VecDeque::from([self]);
329
330 while let Some(node) = queue.pop_front() {
331 match node.expr_type() {
332 ExprType::Leaf => {
333 node.learn_internal_leaf(learning_rate);
334 }
335 ExprType::Unary => {
336 let operand1 = node.operand1.as_mut().expect("Unary expression did not have an operand");
337 operand1.adjust_grad_unary(&node.operation, node.grad, node.result);
338 queue.push_back(operand1);
339 }
340 ExprType::Binary => {
341 let operand1 = node.operand1.as_mut().expect("Binary expression did not have an operand");
342 let operand2 = node.operand2.as_mut().expect("Binary expression did not have a second operand");
343
344 operand1.adjust_grad_binary_op1(&node.operation, node.grad, operand2);
345 operand2.adjust_grad_binary_op2(&node.operation, node.grad, operand1);
346
347 queue.push_back(operand1);
348 queue.push_back(operand2);
349 }
350 }
351 }
352 }
353
354 fn learn_internal_leaf(&mut self, learning_rate: f64) {
355 if self.is_learnable {
358 self.result -= learning_rate * self.grad;
359 }
360 }
361
362 fn adjust_grad_unary(&mut self, child_operation: &Operation, child_grad: f64, child_result: f64) {
363 match child_operation {
364 Operation::Tanh => {
365 let tanh_grad = 1.0 - (child_result * child_result);
366 self.grad = child_grad * tanh_grad;
367 }
368 Operation::Exp => {
369 self.grad = child_grad * child_result;
370 }
371 Operation::ReLU => {
372 self.grad = child_grad * if child_result > 0.0 { 1.0 } else { 0.0 };
373 }
374 Operation::Log => {
375 self.grad = child_grad / child_result;
376 }
377 Operation::Neg => {
378 self.grad = -child_grad;
379 }
380 _ => panic!("Invalid unary operation {:?}", child_operation),
381 }
382 }
383
384 fn adjust_grad_binary_op1(&mut self, child_operation: &Operation, child_grad: f64, operand2: &Expr) {
385 match child_operation {
386 Operation::Add => {
387 self.grad = child_grad;
388 }
389 Operation::Sub => {
390 self.grad = child_grad;
391 }
392 Operation::Mul => {
393 let operand2_result = operand2.result;
394
395 self.grad = child_grad * operand2_result;
396 }
397 Operation::Div => {
398 let operand2_result = operand2.result;
399
400 self.grad = child_grad / operand2_result;
401 }
402 Operation::Pow => {
403 let exponent = operand2.result;
404 let base = self.result;
405
406 self.grad = child_grad * exponent * base.powf(exponent - 1.0);
407 }
408 _ => panic!("Invalid binary operation: {:?}", child_operation),
409 }
410 }
411
412 fn adjust_grad_binary_op2(&mut self,child_operation: &Operation, child_grad: f64, operand1: &Expr) {
413 match child_operation {
414 Operation::Add => {
415 self.grad = child_grad;
416 }
417 Operation::Sub => {
418 self.grad = -child_grad;
419 }
420 Operation::Mul => {
421 let operand1_result = operand1.result;
422 self.grad = child_grad * operand1_result;
423 }
424 Operation::Div => {
425 let operand2_result = self.result;
426 let operand1_result = operand1.result;
427
428 self.grad = -child_grad * operand1_result / (operand2_result * operand2_result);
429 }
430 Operation::Pow => {
431 let exponent = self.result;
432 let base = operand1.result;
433
434 self.grad = child_grad * base.powf(exponent) * base.ln();
435 }
436 _ => panic!("Invalid binary operation: {:?}", child_operation),
437 }
438 }
439
440 pub fn find(&self, name: &str) -> Option<&Expr> {
456 let mut stack = vec![self];
457
458 while let Some(node) = stack.pop() {
459 if node.name == Some(name.to_string()) {
460 return Some(node);
461 }
462
463 if let Some(operand1) = node.operand1.as_ref() {
464 stack.push(operand1);
465 }
466 if let Some(operand2) = node.operand2.as_ref() {
467 stack.push(operand2);
468 }
469 }
470
471 None
472 }
473
474 pub fn find_mut(&mut self, name: &str) -> Option<&mut Expr> {
492 let mut stack = vec![self];
493
494 while let Some(node) = stack.pop() {
495 if node.name == Some(name.to_string()) {
496 return Some(node);
497 }
498
499 if let Some(operand1) = node.operand1.as_mut() {
500 stack.push(operand1);
501 }
502 if let Some(operand2) = node.operand2.as_mut() {
503 stack.push(operand2);
504 }
505 }
506
507 None
508 }
509
510 pub fn parameter_count(&self, learnable_only: bool) -> usize {
526 let mut stack = vec![self];
527 let mut count = 0;
528
529 while let Some(node) = stack.pop() {
530 if node.is_learnable || !learnable_only {
531 count += 1;
532 }
533
534 if let Some(operand1) = node.operand1.as_ref() {
535 stack.push(operand1);
536 }
537 if let Some(operand2) = node.operand2.as_ref() {
538 stack.push(operand2);
539 }
540 }
541
542 count
543 }
544}
545
546impl Add for Expr {
562 type Output = Expr;
563
564 fn add(self, other: Expr) -> Expr {
565 let result = self.result + other.result;
566 Expr::new_binary(self, other, Operation::Add, result)
567 }
568}
569
570impl Add<f64> for Expr {
584 type Output = Expr;
585
586 fn add(self, other: f64) -> Expr {
587 let operand2 = Expr::new_leaf(other);
588 self + operand2
589 }
590}
591
592impl Add<Expr> for f64 {
606 type Output = Expr;
607
608 fn add(self, other: Expr) -> Expr {
609 let operand1 = Expr::new_leaf(self);
610 operand1 + other
611 }
612}
613
614impl Mul for Expr {
631 type Output = Expr;
632
633 fn mul(self, other: Expr) -> Expr {
634 let result = self.result * other.result;
635 Expr::new_binary(self, other, Operation::Mul, result)
636 }
637}
638
639impl Mul<f64> for Expr {
654 type Output = Expr;
655
656 fn mul(self, other: f64) -> Expr {
657 let operand2 = Expr::new_leaf(other);
658 self * operand2
659 }
660}
661
662impl Mul<Expr> for f64 {
677 type Output = Expr;
678
679 fn mul(self, other: Expr) -> Expr {
680 let operand1 = Expr::new_leaf(self);
681 operand1 * other
682 }
683}
684
685impl Sub for Expr {
702 type Output = Expr;
703
704 fn sub(self, other: Expr) -> Expr {
705 let result = self.result - other.result;
706 Expr::new_binary(self, other, Operation::Sub, result)
707 }
708}
709
710impl Sub<f64> for Expr {
725 type Output = Expr;
726
727 fn sub(self, other: f64) -> Expr {
728 let operand2 = Expr::new_leaf(other);
729 self - operand2
730 }
731}
732
733impl Sub<Expr> for f64 {
748 type Output = Expr;
749
750 fn sub(self, other: Expr) -> Expr {
751 let operand1 = Expr::new_leaf(self);
752 operand1 - other
753 }
754}
755
756impl Div for Expr {
773 type Output = Expr;
774
775 fn div(self, other: Expr) -> Expr {
776 let result = self.result / other.result;
777 Expr::new_binary(self, other, Operation::Div, result)
778 }
779}
780
781impl Div<f64> for Expr {
796 type Output = Expr;
797
798 fn div(self, other: f64) -> Expr {
799 let operand2 = Expr::new_leaf(other);
800 self / operand2
801 }
802}
803
804impl Sum for Expr {
825 fn sum<I>(iter: I) -> Self
826 where
827 I: Iterator<Item = Self>,
828 {
829 iter.reduce(|acc, x| acc + x)
830 .unwrap_or(Expr::new_leaf(0.0))
831 }
832}
833
// Unit tests for expression construction, forward evaluation, backpropagation,
// learning and name lookup.
#[cfg(test)]
mod tests {
    use super::*;

    // Approximate float comparison with an absolute tolerance of 1e-7.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    #[test]
    fn test() {
        let expr = Expr::new_leaf(1.0);
        assert_eq!(expr.result, 1.0);
    }

    // `new_unary` stores the given result verbatim; it does not recompute it.
    #[test]
    fn test_unary() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1);

        assert_eq!(expr2.result, 1.1);
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
    }

    // A binary operation must be rejected when building a unary node.
    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let expr = Expr::new_leaf(1.0);
        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1);
    }

    // `new_binary` stores the given result verbatim; it does not recompute it.
    #[test]
    fn test_binary() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_leaf(2.0);
        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1);

        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    // A unary operation must be rejected when building a binary node.
    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_leaf(2.0);
        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0);
    }

    // Unary node over a binary node: checks the nesting is preserved.
    #[test]
    fn test_mixed_tree() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_leaf(2.0);
        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1);
        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2);

        assert_eq!(expr4.result, 1.2);
        let expr3 = expr4.operand1.unwrap();
        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = expr.tanh();

        assert_eq!(expr2.result, 0.7615941559557649);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Tanh);
        assert!(expr2.operand2.is_none());

        fn get_tanh(x: f64) -> f64 {
            // Build a throwaway leaf and read back the eagerly computed tanh value.
            Expr::new_leaf(x).tanh().result
        }

        // Large inputs saturate towards +/-1; tanh(0) is 0.
        assert_float_eq(get_tanh(10.74), 0.9999999);
        assert_float_eq(get_tanh(-10.74), -0.9999999);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = expr.exp();

        assert_eq!(expr2.result, 2.718281828459045);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Exp);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_relu() {
        let expr = Expr::new_leaf(-1.0);
        // Negative input: ReLU clamps to zero.
        let expr2 = expr.relu();

        assert_eq!(expr2.result, 0.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, -1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());

        let expr = Expr::new_leaf(1.0);
        // Positive input: ReLU passes the value through.
        let expr2 = expr.relu();

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = Expr::new_leaf(3.0);
        let result = expr.pow(expr2);

        assert_eq!(result.result, 8.0);
        assert!(result.operand1.is_some());
        assert_eq!(result.operand1.unwrap().result, 2.0);
        assert_eq!(result.operation, Operation::Pow);

        assert!(result.operand2.is_some());
        assert_eq!(result.operand2.unwrap().result, 3.0);
    }

    #[test]
    fn test_add() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_leaf(2.0);
        let expr3 = expr + expr2;

        assert_eq!(expr3.result, 3.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
        assert_eq!(expr3.operation, Operation::Add);
    }

    // `Expr + f64`: the constant becomes the second operand.
    #[test]
    fn test_add_f64() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = expr + 2.0;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Add);
    }

    // `f64 + Expr`: the constant becomes the first operand.
    #[test]
    fn test_add_f64_expr() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = 2.0 + expr;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Add);
    }

    #[test]
    fn test_mul() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = Expr::new_leaf(3.0);
        let expr3 = expr * expr2;

        assert_eq!(expr3.result, 6.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Mul);
    }

    // `Expr * f64`: the constant becomes the second operand.
    #[test]
    fn test_mul_f64() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = expr * 3.0;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Mul);
    }

    // `f64 * Expr`: the constant becomes the first operand.
    #[test]
    fn test_mul_f64_expr() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = 3.0 * expr;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Mul);
    }

    #[test]
    fn test_sub() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = Expr::new_leaf(3.0);
        let expr3 = expr - expr2;

        assert_eq!(expr3.result, -1.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Sub);
    }

    // `Expr - f64`: the constant becomes the second operand.
    #[test]
    fn test_sub_f64() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = expr - 3.0;

        assert_eq!(expr2.result, -1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Sub);
    }

    // `f64 - Expr`: the constant becomes the first operand (3 - 2 = 1).
    #[test]
    fn test_sub_f64_expr() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = 3.0 - expr;

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Sub);
    }

    #[test]
    fn test_div() {
        let expr = Expr::new_leaf(6.0);
        let expr2 = Expr::new_leaf(3.0);
        let expr3 = expr / expr2;

        assert_eq!(expr3.result, 2.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 6.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Div);
    }

    // `Expr / f64`: the constant becomes the second operand.
    #[test]
    fn test_div_f64() {
        let expr = Expr::new_leaf(6.0);
        let expr2 = expr / 3.0;

        assert_eq!(expr2.result, 2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 6.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Div);
    }

    // ln(2).
    #[test]
    fn test_log() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = expr.log();

        assert_eq!(expr2.result, 0.6931471805599453);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Log);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_neg() {
        let expr = Expr::new_leaf(2.0);
        let expr2 = expr.neg();

        assert_eq!(expr2.result, -2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Neg);
        assert!(expr2.operand2.is_none());
    }

    // The backpropagation tests use a tiny learning rate so the leaf updates made by
    // `learn` are negligible; only the gradients are checked.

    // d(a + b)/da = d(a + b)/db = 1.
    #[test]
    fn test_backpropagation_add() {
        let operand1 = Expr::new_leaf(1.0);
        let operand2 = Expr::new_leaf(2.0);
        let mut expr3 = operand1 + operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, 1.0);
    }

    // d(a - b)/da = 1, d(a - b)/db = -1.
    #[test]
    fn test_backpropagation_sub() {
        let operand1 = Expr::new_leaf(1.0);
        let operand2 = Expr::new_leaf(2.0);
        let mut expr3 = operand1 - operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, -1.0);
    }

    // d(a * b)/da = b = 4, d(a * b)/db = a = 3.
    #[test]
    fn test_backpropagation_mul() {
        let operand1 = Expr::new_leaf(3.0);
        let operand2 = Expr::new_leaf(4.0);
        let mut expr3 = operand1 * operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 4.0);
        assert_eq!(operand2.grad, 3.0);
    }

    // d(a / b)/da = 1/b = 0.25, d(a / b)/db = -a/b^2 = -3/16 = -0.1875.
    #[test]
    fn test_backpropagation_div() {
        let operand1 = Expr::new_leaf(3.0);
        let operand2 = Expr::new_leaf(4.0);
        let mut expr3 = operand1 / operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 0.25);
        assert_eq!(operand2.grad, -0.1875);
    }

    // d tanh(x)/dx = 1 - tanh(x)^2 = 1 at x = 0.
    #[test]
    fn test_backpropagation_tanh() {
        let operand1 = Expr::new_leaf(0.0);
        let mut expr2 = operand1.tanh();

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 1.0);
    }

    // ReLU has zero gradient for negative inputs.
    #[test]
    fn test_backpropagation_relu() {
        let operand1 = Expr::new_leaf(-1.0);
        let mut expr2 = operand1.relu();

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 0.0);
    }

    // d e^x/dx = e^x = 1 at x = 0.
    #[test]
    fn test_backpropagation_exp() {
        let operand1 = Expr::new_leaf(0.0);
        let mut expr2 = operand1.exp();

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 1.0);
    }

    // For x^y at x = 2, y = 3: d/dx = y * x^(y-1) = 12, d/dy = x^y * ln(x) = 8 * ln(2).
    #[test]
    fn test_backpropagation_pow() {
        let operand1 = Expr::new_leaf(2.0);
        let operand2 = Expr::new_leaf(3.0);
        let mut expr3 = operand1.pow(operand2);

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 12.0);
        assert_eq!(operand2.grad, 5.545177444479562);
    }

    // tanh(1 + 2): the tanh gradient at 3 is 1 - tanh(3)^2 (about 0.00987), and
    // `Add` passes it unchanged to both operands.
    #[test]
    fn test_backpropagation_mixed_tree() {
        let operand1 = Expr::new_leaf(1.0);
        let operand2 = Expr::new_leaf(2.0);
        let expr3 = operand1 + operand2;
        let mut expr4 = expr3.tanh();

        expr4.learn(1e-09);

        let expr3 = expr4.operand1.unwrap();
        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();

        assert_eq!(expr3.grad, 0.009866037165440211);
        assert_eq!(operand1.grad, 0.009866037165440211);
        assert_eq!(operand2.grad, 0.009866037165440211);
    }

    // Single-neuron example o = tanh(x1*w1 + x2*w2 + b), as in Andrej Karpathy's
    // micrograd walkthrough. b is chosen so that n = 0.8813..., tanh(n) = 1/sqrt(2),
    // and therefore do/dn = 1 - 0.5 = 0.5. Each product's operand then receives the
    // other factor times 0.5.
    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0);
        let x2 = Expr::new_leaf(0.0);
        let w1 = Expr::new_leaf(-3.0);
        let w2 = Expr::new_leaf(1.0);
        let b = Expr::new_leaf(6.8813735870195432);

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh();

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = n.operand1.unwrap();
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = n.operand2.unwrap();
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = x1w1_x2w2.operand1.unwrap();
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = x1w1_x2w2.operand2.unwrap();
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        // dx1 = w1 * 0.5 = -1.5
        let x1 = x1w1.operand1.unwrap();
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        // dw1 = x1 * 0.5 = 1.0
        let w1 = x1w1.operand2.unwrap();
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        // dx2 = w2 * 0.5 = 0.5
        let x2 = x2w2.operand1.unwrap();
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        // dw2 = x2 * 0.5 = 0.0
        let w2 = x2w2.operand2.unwrap();
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    // A lone leaf is its own root: grad = 1, so one step gives 1.0 - 0.1 * 1.0 = 0.9.
    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0);
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    // Clearing `is_learnable` freezes the leaf's value.
    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0);
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr = Expr::new_leaf_with_name(1.0, "x");
        let expr2 = expr.tanh();

        let found = expr2.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, Some("x".to_string()));
    }

    #[test]
    fn test_find_not_found() {
        let expr = Expr::new_leaf_with_name(1.0, "x");
        let expr2 = expr.tanh();

        let found = expr2.find("y");
        assert!(found.is_none());
    }

    // `Sum` chains the terms with `Add`: 1 + 2 + 3 = 6.
    #[test]
    fn test_sum_iterator() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = Expr::new_leaf(2.0);
        let expr3 = Expr::new_leaf(3.0);

        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    // The derived `Clone` copies the boxed operands, so the clone still contains the
    // named leaf.
    #[test]
    fn test_find_after_clone() {
        let expr = Expr::new_leaf_with_name(1.0, "x");
        let expr2 = expr.tanh();
        let expr2_clone = expr2.clone();

        let found = expr2_clone.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, Some("x".to_string()));
    }
}