1#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced an `Expr` node.
///
/// `None` marks a leaf; `Tanh`, `Exp`, `ReLU`, `Log` and `Neg` take one
/// operand; `Add`, `Sub`, `Mul`, `Div` and `Pow` take two.
// Fieldless enum: `Copy` lets it be passed and compared by value freely, and
// `Eq` records that its equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// Leaf node: no operation, the value is stored directly.
    None,
    /// Sum of two operands.
    Add,
    /// Difference of two operands.
    Sub,
    /// Product of two operands.
    Mul,
    /// Quotient of two operands.
    Div,
    /// Hyperbolic tangent of one operand.
    Tanh,
    /// Natural exponential of one operand.
    Exp,
    /// First operand raised to the power of the second.
    Pow,
    /// Rectified linear unit of one operand.
    ReLU,
    /// Natural logarithm of one operand.
    Log,
    /// Negation of one operand.
    Neg,
}
23
24impl Operation {
25 fn assert_is_type(&self, expr_type: ExprType) {
26 match self {
27 Operation::None => assert_eq!(expr_type, ExprType::Leaf),
28 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
29 _ => assert_eq!(expr_type, ExprType::Binary),
30 }
31 }
32}
33
/// Arity class of an `Expr` node: how many operands its operation takes.
// Fieldless enum: deriving `Clone`/`Copy`/`Eq` makes it a plain value type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// No operands; the node stores its value directly.
    Leaf,
    /// Exactly one operand.
    Unary,
    /// Exactly two operands.
    Binary,
}
40
/// A node in a scalar expression tree that supports gradient-based learning.
///
/// A node is either a leaf (created with [`Expr::new_leaf`]), a unary node
/// (one operand) or a binary node (two operands). Operands are owned through
/// `Box`, so every node has exactly one parent and the structure is a tree;
/// cloning an `Expr` copies the whole subtree.
#[derive(Debug, Clone)]
pub struct Expr {
    /// First operand: `Some` for unary and binary nodes, `None` for leaves.
    operand1: Option<Box<Expr>>,
    /// Second operand: `Some` only for binary nodes.
    operand2: Option<Box<Expr>>,
    /// Operation this node applies to its operand(s); `Operation::None` for leaves.
    operation: Operation,
    /// The value of this node as of its last evaluation.
    ///
    /// Computed when the node is built. For learnable leaves it is adjusted by
    /// [`Expr::learn`]; interior nodes pick up those changes only after
    /// [`Expr::recalculate`].
    pub result: f64,
    /// Whether [`Expr::learn`] may update this node's `result`.
    ///
    /// Only consulted for leaves. Leaves start out learnable; interior nodes
    /// are created with `false`.
    pub is_learnable: bool,
    /// Gradient of the root on which `learn` was called with respect to this
    /// node, filled in during `learn`.
    grad: f64,
    /// Label of this node, matched by [`Expr::find`].
    ///
    /// The arithmetic operator impls derive labels such as `(x + y)` from
    /// their operands' names.
    pub name: String,
}
63
64impl Expr {
65 pub fn new_leaf(value: f64, name: &str) -> Expr {
74 Expr {
75 operand1: None,
76 operand2: None,
77 operation: Operation::None,
78 result: value,
79 is_learnable: true,
80 grad: 0.0,
81 name: name.to_string(),
82 }
83 }
84
85 fn expr_type(&self) -> ExprType {
86 match self.operation {
87 Operation::None => ExprType::Leaf,
88 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
89 _ => ExprType::Binary,
90 }
91 }
92
93 fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
94 operation.assert_is_type(ExprType::Unary);
95 Expr {
96 operand1: Some(Box::new(operand)),
97 operand2: None,
98 operation,
99 result,
100 is_learnable: false,
101 grad: 0.0,
102 name: name.to_string(),
103 }
104 }
105
106 fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
107 operation.assert_is_type(ExprType::Binary);
108 Expr {
109 operand1: Some(Box::new(operand1)),
110 operand2: Some(Box::new(operand2)),
111 operation,
112 result,
113 is_learnable: false,
114 grad: 0.0,
115 name: name.to_string(),
116 }
117 }
118
119 pub fn tanh(self, name: &str) -> Expr {
131 let result = self.result.tanh();
132 Expr::new_unary(self, Operation::Tanh, result, name)
133 }
134
135 pub fn relu(self, name: &str) -> Expr {
147 let result = self.result.max(0.0);
148 Expr::new_unary(self, Operation::ReLU, result, name)
149 }
150
151 pub fn exp(self, name: &str) -> Expr {
163 let result = self.result.exp();
164 Expr::new_unary(self, Operation::Exp, result, name)
165 }
166
167 pub fn pow(self, exponent: Expr, name: &str) -> Expr {
180 let result = self.result.powf(exponent.result);
181 Expr::new_binary(self, exponent, Operation::Pow, result, name)
182 }
183
184 pub fn log(self, name: &str) -> Expr {
196 let result = self.result.ln();
197 Expr::new_unary(self, Operation::Log, result, name)
198 }
199
200 pub fn neg(self, name: &str) -> Expr {
212 let result = -self.result;
213 Expr::new_unary(self, Operation::Neg, result, name)
214 }
215
216 pub fn recalculate(&mut self) {
231 match self.expr_type() {
232 ExprType::Leaf => {}
233 ExprType::Unary => {
234 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
235 operand1.recalculate();
236
237 self.result = match self.operation {
238 Operation::Tanh => operand1.result.tanh(),
239 Operation::Exp => operand1.result.exp(),
240 Operation::ReLU => operand1.result.max(0.0),
241 _ => panic!("Invalid unary operation {:?}", self.operation),
242 };
243 }
244 ExprType::Binary => {
245 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
246 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
247
248 operand1.recalculate();
249 operand2.recalculate();
250
251 self.result = match self.operation {
252 Operation::Add => operand1.result + operand2.result,
253 Operation::Sub => operand1.result - operand2.result,
254 Operation::Mul => operand1.result * operand2.result,
255 Operation::Div => operand1.result / operand2.result,
256 Operation::Pow => operand1.result.powf(operand2.result),
257 _ => panic!("Invalid binary operation: {:?}", self.operation),
258 };
259 }
260 }
261 }
262
263 pub fn learn(&mut self, learning_rate: f64) {
285 self.grad = 1.0;
286 self.learn_internal(learning_rate);
287 }
288
289 fn learn_internal(&mut self, learning_rate: f64) {
290 match self.expr_type() {
291 ExprType::Leaf => {
292 if self.is_learnable {
295 self.result -= learning_rate * self.grad;
296 }
297 }
298 ExprType::Unary => {
299 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
300
301 match self.operation {
302 Operation::Tanh => {
303 let tanh_grad = 1.0 - (self.result * self.result);
304 operand1.grad = self.grad * tanh_grad;
305 }
306 Operation::Exp => {
307 operand1.grad = self.grad * self.result;
308 }
309 Operation::ReLU => {
310 operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
311 }
312 Operation::Log => {
313 operand1.grad = self.grad / operand1.result;
314 }
315 Operation::Neg => {
316 operand1.grad = -self.grad;
317 }
318 _ => panic!("Invalid unary operation {:?}", self.operation),
319 }
320
321 operand1.learn_internal(learning_rate);
322 }
323 ExprType::Binary => {
324 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
325 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
326
327 match self.operation {
328 Operation::Add => {
329 operand1.grad = self.grad;
330 operand2.grad = self.grad;
331 }
332 Operation::Sub => {
333 operand1.grad = self.grad;
334 operand2.grad = -self.grad;
335 }
336 Operation::Mul => {
337 let operand2_result = operand2.result;
338 let operand1_result = operand1.result;
339
340 operand1.grad = self.grad * operand2_result;
341 operand2.grad = self.grad * operand1_result;
342 }
343 Operation::Div => {
344 let operand2_result = operand2.result;
345 let operand1_result = operand1.result;
346
347 operand1.grad = self.grad / operand2_result;
348 operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
349 }
350 Operation::Pow => {
351 let exponent = operand2.result;
352 let base = operand1.result;
353
354 operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
355 operand2.grad = self.grad * base.powf(exponent) * base.ln();
356 }
357 _ => panic!("Invalid binary operation: {:?}", self.operation),
358 }
359
360 operand1.learn_internal(learning_rate);
361 operand2.learn_internal(learning_rate);
362 }
363 }
364 }
365
366 pub fn find(&self, name: &str) -> Option<&Expr> {
382 if self.name == name {
383 return Some(self);
384 }
385
386 match self.expr_type() {
387 ExprType::Leaf => None,
388 ExprType::Unary => {
389 let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
390 operand1.find(name)
391 }
392 ExprType::Binary => {
393 let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
394 let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
395
396 let result = operand1.find(name);
397 if result.is_some() {
398 return result;
399 }
400
401 operand2.find(name)
402 }
403 }
404 }
405}
406
407impl Add for Expr {
423 type Output = Expr;
424
425 fn add(self, other: Expr) -> Expr {
426 let result = self.result + other.result;
427 let name = &format!("({} + {})", self.name, other.name);
428 Expr::new_binary(self, other, Operation::Add, result, name)
429 }
430}
431
432impl Add<f64> for Expr {
446 type Output = Expr;
447
448 fn add(self, other: f64) -> Expr {
449 let operand2 = Expr::new_leaf(other, &other.to_string());
450 self + operand2
451 }
452}
453
454impl Add<Expr> for f64 {
468 type Output = Expr;
469
470 fn add(self, other: Expr) -> Expr {
471 let operand1 = Expr::new_leaf(self, &self.to_string());
472 operand1 + other
473 }
474}
475
476impl Mul for Expr {
493 type Output = Expr;
494
495 fn mul(self, other: Expr) -> Expr {
496 let result = self.result * other.result;
497 let name = &format!("({} * {})", self.name, other.name);
498 Expr::new_binary(self, other, Operation::Mul, result, name)
499 }
500}
501
502impl Mul<f64> for Expr {
517 type Output = Expr;
518
519 fn mul(self, other: f64) -> Expr {
520 let operand2 = Expr::new_leaf(other, &other.to_string());
521 self * operand2
522 }
523}
524
525impl Mul<Expr> for f64 {
540 type Output = Expr;
541
542 fn mul(self, other: Expr) -> Expr {
543 let operand1 = Expr::new_leaf(self, &self.to_string());
544 operand1 * other
545 }
546}
547
548impl Sub for Expr {
565 type Output = Expr;
566
567 fn sub(self, other: Expr) -> Expr {
568 let result = self.result - other.result;
569 let name = &format!("({} - {})", self.name, other.name);
570 Expr::new_binary(self, other, Operation::Sub, result, name)
571 }
572}
573
574impl Sub<f64> for Expr {
589 type Output = Expr;
590
591 fn sub(self, other: f64) -> Expr {
592 let operand2 = Expr::new_leaf(other, &other.to_string());
593 self - operand2
594 }
595}
596
597impl Sub<Expr> for f64 {
612 type Output = Expr;
613
614 fn sub(self, other: Expr) -> Expr {
615 let operand1 = Expr::new_leaf(self, &self.to_string());
616 operand1 - other
617 }
618}
619
620impl Div for Expr {
637 type Output = Expr;
638
639 fn div(self, other: Expr) -> Expr {
640 let result = self.result / other.result;
641 let name = &format!("({} / {})", self.name, other.name);
642 Expr::new_binary(self, other, Operation::Div, result, name)
643 }
644}
645
646impl Div<f64> for Expr {
661 type Output = Expr;
662
663 fn div(self, other: f64) -> Expr {
664 let operand2 = Expr::new_leaf(other, &other.to_string());
665 self / operand2
666 }
667}
668
669impl Sum for Expr {
690 fn sum<I>(iter: I) -> Self
691 where
692 I: Iterator<Item = Self>,
693 {
694 iter.reduce(|acc, x| acc + x)
695 .unwrap_or(Expr::new_leaf(0.0, "0.0"))
696 }
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `f1` and `f2` differ by less than `1e-7`.
    ///
    /// Used for values produced by `tanh`, `exp`, `ln` and `powf`. The
    /// standard library leaves the precision of those functions unspecified
    /// across platforms, so exact `assert_eq!` comparisons on them are brittle.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    #[test]
    fn test_new_leaf() {
        let expr = Expr::new_leaf(1.0, "x");
        assert_eq!(expr.result, 1.0);
    }

    #[test]
    fn test_unary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");

        assert_eq!(expr2.result, 1.1);
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");

        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(expr4.result, 1.2);
        let expr3 = expr4.operand1.unwrap();
        assert_eq!(expr3.result, 1.1);
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        assert_float_eq(expr2.result, 0.7615941559557649);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Tanh);
        assert!(expr2.operand2.is_none());

        fn get_tanh(x: f64) -> f64 {
            Expr::new_leaf(x, "x").tanh("tanh(x)").result
        }

        assert_float_eq(get_tanh(10.74), 0.9999999);
        assert_float_eq(get_tanh(-10.74), -0.9999999);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.exp("exp(x)");

        assert_float_eq(expr2.result, 2.718281828459045);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Exp);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_relu() {
        let expr = Expr::new_leaf(-1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 0.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, -1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());

        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.relu("relu(x)");

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::ReLU);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let result = expr.pow(expr2, "x^y");

        assert_float_eq(result.result, 8.0);
        assert!(result.operand1.is_some());
        assert_eq!(result.operand1.unwrap().result, 2.0);
        assert_eq!(result.operation, Operation::Pow);

        assert!(result.operand2.is_some());
        assert_eq!(result.operand2.unwrap().result, 3.0);
    }

    #[test]
    fn test_add() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = expr + expr2;

        assert_eq!(expr3.result, 3.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 1.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 2.0);
        assert_eq!(expr3.operation, Operation::Add);
        assert_eq!(expr3.name, "(x + y)");
    }

    #[test]
    fn test_add_f64() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr + 2.0;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 1.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(x + 2)");
    }

    #[test]
    fn test_add_f64_expr() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = 2.0 + expr;

        assert_eq!(expr2.result, 3.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 1.0);
        assert_eq!(expr2.operation, Operation::Add);
        assert_eq!(expr2.name, "(2 + x)");
    }

    #[test]
    fn test_mul() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr * expr2;

        assert_eq!(expr3.result, 6.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Mul);
        assert_eq!(expr3.name, "(x * y)");
    }

    #[test]
    fn test_mul_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr * 3.0;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(x * 3)");
    }

    #[test]
    fn test_mul_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 * expr;

        assert_eq!(expr2.result, 6.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Mul);
        assert_eq!(expr2.name, "(3 * x)");
    }

    #[test]
    fn test_sub() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr - expr2;

        assert_eq!(expr3.result, -1.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 2.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Sub);
        assert_eq!(expr3.name, "(x - y)");
    }

    #[test]
    fn test_sub_f64() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr - 3.0;

        assert_eq!(expr2.result, -1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(x - 3)");
    }

    #[test]
    fn test_sub_f64_expr() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = 3.0 - expr;

        assert_eq!(expr2.result, 1.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 3.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Sub);
        assert_eq!(expr2.name, "(3 - x)");
    }

    #[test]
    fn test_div() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let expr3 = expr / expr2;

        assert_eq!(expr3.result, 2.0);
        assert!(expr3.operand1.is_some());
        assert_eq!(expr3.operand1.unwrap().result, 6.0);
        assert!(expr3.operand2.is_some());
        assert_eq!(expr3.operand2.unwrap().result, 3.0);
        assert_eq!(expr3.operation, Operation::Div);
        assert_eq!(expr3.name, "(x / y)");
    }

    #[test]
    fn test_div_f64() {
        let expr = Expr::new_leaf(6.0, "x");
        let expr2 = expr / 3.0;

        assert_eq!(expr2.result, 2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 6.0);
        assert!(expr2.operand2.is_some());
        assert_eq!(expr2.operand2.unwrap().result, 3.0);
        assert_eq!(expr2.operation, Operation::Div);
        assert_eq!(expr2.name, "(x / 3)");
    }

    #[test]
    fn test_log() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr.log("log(x)");

        assert_float_eq(expr2.result, 0.6931471805599453);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Log);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_neg() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = expr.neg("neg(x)");

        assert_eq!(expr2.result, -2.0);
        assert!(expr2.operand1.is_some());
        assert_eq!(expr2.operand1.unwrap().result, 2.0);
        assert_eq!(expr2.operation, Operation::Neg);
        assert!(expr2.operand2.is_none());
    }

    #[test]
    fn test_backpropagation_add() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 + operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_sub() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let mut expr3 = operand1 - operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 1.0);
        assert_eq!(operand2.grad, -1.0);
    }

    #[test]
    fn test_backpropagation_mul() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 * operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 4.0);
        assert_eq!(operand2.grad, 3.0);
    }

    #[test]
    fn test_backpropagation_div() {
        let operand1 = Expr::new_leaf(3.0, "x");
        let operand2 = Expr::new_leaf(4.0, "y");
        let mut expr3 = operand1 / operand2;

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_eq!(operand1.grad, 0.25);
        assert_eq!(operand2.grad, -0.1875);
    }

    #[test]
    fn test_backpropagation_tanh() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.tanh("tanh(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_float_eq(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let operand1 = Expr::new_leaf(-1.0, "x");
        let mut expr2 = operand1.relu("relu(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_eq!(operand1.grad, 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let operand1 = Expr::new_leaf(0.0, "x");
        let mut expr2 = operand1.exp("exp(x)");

        expr2.learn(1e-09);

        let operand1 = expr2.operand1.unwrap();
        assert_float_eq(operand1.grad, 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let operand2 = Expr::new_leaf(3.0, "y");
        let mut expr3 = operand1.pow(operand2, "x^y");

        expr3.learn(1e-09);

        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();
        assert_float_eq(operand1.grad, 12.0);
        assert_float_eq(operand2.grad, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let operand1 = Expr::new_leaf(1.0, "x");
        let operand2 = Expr::new_leaf(2.0, "y");
        let expr3 = operand1 + operand2;
        let mut expr4 = expr3.tanh("tanh(x + y)");

        expr4.learn(1e-09);

        let expr3 = expr4.operand1.unwrap();
        let operand1 = expr3.operand1.unwrap();
        let operand2 = expr3.operand2.unwrap();

        assert_float_eq(expr3.grad, 0.009866037165440211);
        assert_float_eq(operand1.grad, 0.009866037165440211);
        assert_float_eq(operand2.grad, 0.009866037165440211);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.8813735870195432, "b");

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh("tanh(n)");

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = n.operand1.unwrap();
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = n.operand2.unwrap();
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = x1w1_x2w2.operand1.unwrap();
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = x1w1_x2w2.operand2.unwrap();
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        let x1 = x1w1.operand1.unwrap();
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        let w1 = x1w1.operand2.unwrap();
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        let x2 = x2w2.operand1.unwrap();
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        let w2 = x2w2.operand2.unwrap();
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");

        let found = expr2.find("y");
        assert!(found.is_none());
    }

    #[test]
    fn test_sum_iterator() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_leaf(3.0, "z");

        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    #[test]
    fn test_find_after_clone() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = expr.tanh("tanh(x)");
        let expr2_clone = expr2.clone();

        let found = expr2_clone.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }
}