1#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation a node of the expression tree applies to its operands.
///
/// `None` marks a leaf. `Tanh`, `Exp` and `ReLU` take one operand; `Add`,
/// `Sub`, `Mul`, `Div` and `Pow` take two. The enum carries no data, so it is
/// `Copy` and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// No operation: the node is a leaf holding a plain value.
    None,
    /// `operand1 + operand2`.
    Add,
    /// `operand1 - operand2`.
    Sub,
    /// `operand1 * operand2`.
    Mul,
    /// `operand1 / operand2`.
    Div,
    /// Hyperbolic tangent of `operand1`.
    Tanh,
    /// `e` raised to the power `operand1`.
    Exp,
    /// `operand1` raised to the power `operand2`.
    Pow,
    /// Rectified linear unit: `max(operand1, 0)`.
    ReLU,
}
21
22impl Operation {
23 fn assert_is_type(&self, expr_type: ExprType) {
24 match self {
25 Operation::None => assert_eq!(expr_type, ExprType::Leaf),
26 Operation::Tanh | Operation::Exp | Operation::ReLU => assert_eq!(expr_type, ExprType::Unary),
27 _ => assert_eq!(expr_type, ExprType::Binary),
28 }
29 }
30}
31
/// Arity class of an expression node: how many operands its operation uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// No operands.
    Leaf,
    /// One operand (`operand1`).
    Unary,
    /// Two operands (`operand1` and `operand2`).
    Binary,
}
38
39#[derive(Debug, Clone)]
49pub struct Expr {
50 operand1: Option<Box<Expr>>,
51 operand2: Option<Box<Expr>>,
52 operation: Operation,
53 pub result: f64,
55 pub is_learnable: bool,
57 grad: f64,
58 pub name: String,
60}
61
62impl Expr {
63 pub fn new_leaf(value: f64, name: &str) -> Expr {
72 Expr {
73 operand1: None,
74 operand2: None,
75 operation: Operation::None,
76 result: value,
77 is_learnable: true,
78 grad: 0.0,
79 name: name.to_string(),
80 }
81 }
82
83 fn expr_type(&self) -> ExprType {
84 match self.operation {
85 Operation::None => ExprType::Leaf,
86 Operation::Tanh | Operation::Exp | Operation::ReLU => ExprType::Unary,
87 _ => ExprType::Binary,
88 }
89 }
90
91 fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
92 operation.assert_is_type(ExprType::Unary);
93 Expr {
94 operand1: Some(Box::new(operand)),
95 operand2: None,
96 operation,
97 result,
98 is_learnable: true,
99 grad: 0.0,
100 name: name.to_string(),
101 }
102 }
103
104 fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
105 operation.assert_is_type(ExprType::Binary);
106 Expr {
107 operand1: Some(Box::new(operand1)),
108 operand2: Some(Box::new(operand2)),
109 operation,
110 result,
111 is_learnable: true,
112 grad: 0.0,
113 name: name.to_string(),
114 }
115 }
116
117 pub fn tanh(self, name: &str) -> Expr {
129 let result = self.result.tanh();
130 Expr::new_unary(self, Operation::Tanh, result, name)
131 }
132
133 pub fn relu(self, name: &str) -> Expr {
145 let result = self.result.max(0.0);
146 Expr::new_unary(self, Operation::ReLU, result, name)
147 }
148
149 pub fn exp(self, name: &str) -> Expr {
161 let result = self.result.exp();
162 Expr::new_unary(self, Operation::Exp, result, name)
163 }
164
165 pub fn pow(self, exponent: Expr, name: &str) -> Expr {
178 let result = self.result.powf(exponent.result);
179 Expr::new_binary(self, exponent, Operation::Pow, result, name)
180 }
181
182 pub fn recalculate(&mut self) {
197 match self.expr_type() {
198 ExprType::Leaf => {}
199 ExprType::Unary => {
200 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
201 operand1.recalculate();
202
203 self.result = match self.operation {
204 Operation::Tanh => operand1.result.tanh(),
205 Operation::Exp => operand1.result.exp(),
206 Operation::ReLU => operand1.result.max(0.0),
207 _ => panic!("Invalid unary operation {:?}", self.operation),
208 };
209 }
210 ExprType::Binary => {
211 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
212 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
213
214 operand1.recalculate();
215 operand2.recalculate();
216
217 self.result = match self.operation {
218 Operation::Add => operand1.result + operand2.result,
219 Operation::Sub => operand1.result - operand2.result,
220 Operation::Mul => operand1.result * operand2.result,
221 Operation::Div => operand1.result / operand2.result,
222 Operation::Pow => operand1.result.powf(operand2.result),
223 _ => panic!("Invalid binary operation: {:?}", self.operation),
224 };
225 }
226 }
227 }
228
229 pub fn learn(&mut self, learning_rate: f64) {
251 self.grad = 1.0;
252 self.learn_internal(learning_rate);
253 }
254
255 fn learn_internal(&mut self, learning_rate: f64) {
256 match self.expr_type() {
257 ExprType::Leaf => {
258 if self.is_learnable {
261 self.result -= learning_rate * self.grad;
262 }
263 }
264 ExprType::Unary => {
265 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
266
267 match self.operation {
268 Operation::Tanh => {
269 let tanh_grad = 1.0 - (self.result * self.result);
270 operand1.grad = self.grad * tanh_grad;
271 }
272 Operation::Exp => {
273 operand1.grad = self.grad * self.result;
274 }
275 Operation::ReLU => {
276 operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
277 }
278 _ => panic!("Invalid unary operation {:?}", self.operation),
279 }
280
281 operand1.learn_internal(learning_rate);
282 }
283 ExprType::Binary => {
284 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
285 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
286
287 match self.operation {
288 Operation::Add => {
289 operand1.grad = self.grad;
290 operand2.grad = self.grad;
291 }
292 Operation::Sub => {
293 operand1.grad = self.grad;
294 operand2.grad = -self.grad;
295 }
296 Operation::Mul => {
297 let operand2_result = operand2.result;
298 let operand1_result = operand1.result;
299
300 operand1.grad = self.grad * operand2_result;
301 operand2.grad = self.grad * operand1_result;
302 }
303 Operation::Div => {
304 let operand2_result = operand2.result;
305 let operand1_result = operand1.result;
306
307 operand1.grad = self.grad / operand2_result;
308 operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
309 }
310 Operation::Pow => {
311 let exponent = operand2.result;
312 let base = operand1.result;
313
314 operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
315 operand2.grad = self.grad * base.powf(exponent) * base.ln();
316 }
317 _ => panic!("Invalid binary operation: {:?}", self.operation),
318 }
319
320 operand1.learn_internal(learning_rate);
321 operand2.learn_internal(learning_rate);
322 }
323 }
324 }
325
326 pub fn find(&self, name: &str) -> Option<&Expr> {
342 if self.name == name {
343 return Some(self);
344 }
345
346 match self.expr_type() {
347 ExprType::Leaf => None,
348 ExprType::Unary => {
349 let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
350 operand1.find(name)
351 }
352 ExprType::Binary => {
353 let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
354 let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
355
356 let result = operand1.find(name);
357 if result.is_some() {
358 return result;
359 }
360
361 operand2.find(name)
362 }
363 }
364 }
365}
366
367impl Add for Expr {
383 type Output = Expr;
384
385 fn add(self, other: Expr) -> Expr {
386 let result = self.result + other.result;
387 let name = &format!("({} + {})", self.name, other.name);
388 Expr::new_binary(self, other, Operation::Add, result, name)
389 }
390}
391
392impl Add<f64> for Expr {
406 type Output = Expr;
407
408 fn add(self, other: f64) -> Expr {
409 let operand2 = Expr::new_leaf(other, &other.to_string());
410 self + operand2
411 }
412}
413
414impl Add<Expr> for f64 {
428 type Output = Expr;
429
430 fn add(self, other: Expr) -> Expr {
431 let operand1 = Expr::new_leaf(self, &self.to_string());
432 operand1 + other
433 }
434}
435
436impl Mul for Expr {
453 type Output = Expr;
454
455 fn mul(self, other: Expr) -> Expr {
456 let result = self.result * other.result;
457 let name = &format!("({} * {})", self.name, other.name);
458 Expr::new_binary(self, other, Operation::Mul, result, name)
459 }
460}
461
462impl Mul<f64> for Expr {
477 type Output = Expr;
478
479 fn mul(self, other: f64) -> Expr {
480 let operand2 = Expr::new_leaf(other, &other.to_string());
481 self * operand2
482 }
483}
484
485impl Mul<Expr> for f64 {
500 type Output = Expr;
501
502 fn mul(self, other: Expr) -> Expr {
503 let operand1 = Expr::new_leaf(self, &self.to_string());
504 operand1 * other
505 }
506}
507
508impl Sub for Expr {
525 type Output = Expr;
526
527 fn sub(self, other: Expr) -> Expr {
528 let result = self.result - other.result;
529 let name = &format!("({} - {})", self.name, other.name);
530 Expr::new_binary(self, other, Operation::Sub, result, name)
531 }
532}
533
534impl Sub<f64> for Expr {
549 type Output = Expr;
550
551 fn sub(self, other: f64) -> Expr {
552 let operand2 = Expr::new_leaf(other, &other.to_string());
553 self - operand2
554 }
555}
556
557impl Sub<Expr> for f64 {
572 type Output = Expr;
573
574 fn sub(self, other: Expr) -> Expr {
575 let operand1 = Expr::new_leaf(self, &self.to_string());
576 operand1 - other
577 }
578}
579
580impl Div for Expr {
597 type Output = Expr;
598
599 fn div(self, other: Expr) -> Expr {
600 let result = self.result / other.result;
601 let name = &format!("({} / {})", self.name, other.name);
602 Expr::new_binary(self, other, Operation::Div, result, name)
603 }
604}
605
606impl Div<f64> for Expr {
621 type Output = Expr;
622
623 fn div(self, other: f64) -> Expr {
624 let operand2 = Expr::new_leaf(other, &other.to_string());
625 self / operand2
626 }
627}
628
629impl Sum for Expr {
650 fn sum<I>(iter: I) -> Self
651 where
652 I: Iterator<Item = Self>,
653 {
654 iter.reduce(|acc, x| acc + x)
655 .unwrap_or(Expr::new_leaf(0.0, "0.0"))
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within `1e-7`.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    #[test]
    fn test() {
        let expr = Expr::new_leaf(1.0, "x");
        assert_eq!(expr.result, 1.0);
    }

    #[test]
    fn test_unary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");

        assert_eq!(expr2.result, 1.1);
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");

        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(expr4.result, 1.2);
        let expr3 = expr4.operand1.unwrap();
        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        assert_eq!(expr2.result, 0.7615941559557649);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Tanh);
        assert!(expr2.operand2.is_none());

        fn get_tanh(x: f64) -> f64 {
            Expr::new_leaf(x, "x").tanh("tanh(x)").result
        }

        assert_float_eq(get_tanh(10.74), 0.9999999);
        assert_float_eq(get_tanh(-10.74), -0.9999999);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.exp("exp(x)");

        assert_eq!(expr2.result, 2.718281828459045);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Exp);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_relu() {
        let expr = Expr::new_leaf(-1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 0.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, -1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());

        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let result = expr.pow(expr2, "x^y");

        assert_eq!(result.result, 8.0);
        assert!(result.operand1.is_some());
        assert_eq!(result.operand1.unwrap().result, 2.0);
        assert_eq!(result.operation, Operation::Pow);

        assert!(result.operand2.is_some());
        assert_eq!(result.operand2.unwrap().result, 3.0);
    }

    #[test]
    fn test_add() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = expr + expr2;

        assert_eq!(expr3.result, 3.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
        assert_eq!(expr3.operation, Operation::Add);
        assert_eq!(expr3.name, "(x + y)");
    }

    #[test]
    fn test_add_f64() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr + 2.0;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(x + 2)");
    }

    #[test]
    fn test_add_f64_expr() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = 2.0 + expr;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(2 + x)");
    }

    #[test]
    fn test_mul() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr * expr2;

        assert_eq!(expr3.result, 6.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Mul);
        assert_eq!(expr3.name, "(x * y)");
    }

    #[test]
    fn test_mul_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr * 3.0;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(x * 3)");
    }

    #[test]
    fn test_mul_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 * expr;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(3 * x)");
    }

    #[test]
    fn test_sub() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr - expr2;

        assert_eq!(expr3.result, -1.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Sub);
        assert_eq!(expr3.name, "(x - y)");
    }

    #[test]
    fn test_sub_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr - 3.0;

        assert_eq!(expr2.result, -1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(x - 3)");
    }

    #[test]
    fn test_sub_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 - expr;

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(3 - x)");
    }

    #[test]
    fn test_div() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr / expr2;

        assert_eq!(expr3.result, 2.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 6.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Div);
        assert_eq!(expr3.name, "(x / y)");
    }

    #[test]
    fn test_div_f64() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = expr / 3.0;

        assert_eq!(expr2.result, 2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 6.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Div);
        assert_eq!(expr2.name, "(x / 3)");
    }

    #[test]
    fn test_backpropagation_add() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 + operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_sub() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 - operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, -1.0);
    }

    #[test]
    fn test_backpropagation_mul() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 * operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 4.0);
        assert_eq!(operand2.grad, 3.0);
    }

    #[test]
    fn test_backpropagation_div() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 / operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 0.25);
        assert_eq!(operand2.grad, -0.1875);
    }

    #[test]
    fn test_backpropagation_tanh() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.tanh("tanh(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let operand1 = Expr::new_leaf(-1.0, "x");
        let mut expr2 = operand1.relu("relu(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.exp("exp(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let operand2 = Expr::new_leaf(3.0, "y");
        let mut expr3 = operand1.pow(operand2, "x^y");

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 12.0);
        assert_eq!(operand2.grad, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let expr3 = operand1 + operand2;
        let mut expr4 = expr3.tanh("tanh(x + y)");

        expr4.learn(1e-09);

        let expr3 = expr4.operand1.unwrap();
        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();

        assert_eq!(expr3.grad, 0.009866037165440211);
        assert_eq!(operand1.grad, 0.009866037165440211);
        assert_eq!(operand2.grad, 0.009866037165440211);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.8813735870195432, "b");

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh("tanh(n)");

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = n.operand1.unwrap();
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = n.operand2.unwrap();
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = x1w1_x2w2.operand1.unwrap();
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = x1w1_x2w2.operand2.unwrap();
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        let x1 = x1w1.operand1.unwrap();
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        let w1 = x1w1.operand2.unwrap();
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        let x2 = x2w2.operand1.unwrap();
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        let w2 = x2w2.operand2.unwrap();
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("y");
        assert!(found.is_none());
    }

    #[test]
    fn test_sum_iterator() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_leaf(3.0, "z");

        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    #[test]
    fn test_find_after_clone() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");
        let expr2_clone = expr2.clone();

        let found = expr2_clone.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }

    #[test]
    fn test_recalculate_propagates_leaf_change() {
        let x = Expr::new_leaf(2.0, "x");
        let y = Expr::new_leaf(3.0, "y");
        let mut expr = (x * y).tanh("tanh(x * y)");

        // Change `x` directly; only `recalculate` refreshes the parent nodes.
        let product = expr.operand1.as_mut().unwrap();
        product.operand1.as_mut().unwrap().result = 0.0;
        expr.recalculate();

        assert_eq!(expr.operand1.as_ref().unwrap().result, 0.0);
        assert_eq!(expr.result, 0.0);
    }

    #[test]
    fn test_recalculate_after_learn() {
        let x = Expr::new_leaf(2.0, "x");
        let y = Expr::new_leaf(3.0, "y");
        let mut expr = x + y;

        expr.learn(0.5);
        // `learn` only moves the leaves; the sum still holds its old value.
        assert_eq!(expr.result, 5.0);

        expr.recalculate();
        assert_float_eq(expr.result, 4.0);
    }

    #[test]
    fn test_find_in_binary_tree() {
        let expr = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");

        assert_eq!(expr.find("y").map(|e| e.result), Some(2.0));
        assert_eq!(expr.find("(x + y)").map(|e| e.result), Some(3.0));
        assert!(expr.find("z").is_none());
    }

    #[test]
    fn test_sum_empty_iterator() {
        let sum: Expr = Vec::<Expr>::new().into_iter().sum();
        assert_eq!(sum.result, 0.0);
        assert_eq!(sum.name, "0.0");
    }
}