1#![deny(missing_docs)]
6use std::collections::VecDeque;
7use std::ops::{Add, Div, Mul, Sub};
8use std::iter::Sum;
9
/// The operation that produced an [`Expr`] node from its operand(s).
///
/// `None` marks a leaf; `Tanh`, `Exp`, `ReLU`, `Log` and `Neg` take one
/// operand; `Add`, `Sub`, `Mul`, `Div` and `Pow` take two.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// No operation: the node is a leaf that stores its value directly.
    None,
    /// Addition of two operands, `a + b`.
    Add,
    /// Subtraction of two operands, `a - b`.
    Sub,
    /// Multiplication of two operands, `a * b`.
    Mul,
    /// Division of two operands, `a / b`.
    Div,
    /// Hyperbolic tangent of one operand, `tanh(a)`.
    Tanh,
    /// Exponential of one operand, `e^a`.
    Exp,
    /// Power of two operands, `a^b` (base first, exponent second).
    Pow,
    /// Rectified linear unit of one operand, `max(a, 0)`.
    ReLU,
    /// Natural logarithm of one operand, `ln(a)`.
    Log,
    /// Negation of one operand, `-a`.
    Neg,
}
24
25impl Operation {
26 fn assert_is_type(&self, expr_type: ExprType) {
27 match self {
28 Operation::None => assert_eq!(expr_type, ExprType::Leaf),
29 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
30 _ => assert_eq!(expr_type, ExprType::Binary),
31 }
32 }
33}
34
/// Shape of an expression node, determined by how many operands it has.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// No operands: the node stores a value directly.
    Leaf,
    /// Exactly one operand, held in `operand1`.
    Unary,
    /// Two operands, held in `operand1` and `operand2`.
    Binary,
}
41
42#[derive(Debug, Clone)]
52pub struct Expr {
53 operand1: Option<Box<Expr>>,
54 operand2: Option<Box<Expr>>,
55 operation: Operation,
56 pub result: f64,
58 pub is_learnable: bool,
60 grad: f64,
61 pub name: String,
63}
64
65impl Expr {
66 pub fn new_leaf(value: f64, name: &str) -> Expr {
75 Expr {
76 operand1: None,
77 operand2: None,
78 operation: Operation::None,
79 result: value,
80 is_learnable: true,
81 grad: 0.0,
82 name: name.to_string(),
83 }
84 }
85
86 fn expr_type(&self) -> ExprType {
87 match self.operation {
88 Operation::None => ExprType::Leaf,
89 Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
90 _ => ExprType::Binary,
91 }
92 }
93
94 fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
95 operation.assert_is_type(ExprType::Unary);
96 Expr {
97 operand1: Some(Box::new(operand)),
98 operand2: None,
99 operation,
100 result,
101 is_learnable: false,
102 grad: 0.0,
103 name: name.to_string(),
104 }
105 }
106
107 fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
108 operation.assert_is_type(ExprType::Binary);
109 Expr {
110 operand1: Some(Box::new(operand1)),
111 operand2: Some(Box::new(operand2)),
112 operation,
113 result,
114 is_learnable: false,
115 grad: 0.0,
116 name: name.to_string(),
117 }
118 }
119
120 pub fn tanh(self, name: &str) -> Expr {
132 let result = self.result.tanh();
133 Expr::new_unary(self, Operation::Tanh, result, name)
134 }
135
136 pub fn relu(self, name: &str) -> Expr {
148 let result = self.result.max(0.0);
149 Expr::new_unary(self, Operation::ReLU, result, name)
150 }
151
152 pub fn exp(self, name: &str) -> Expr {
164 let result = self.result.exp();
165 Expr::new_unary(self, Operation::Exp, result, name)
166 }
167
168 pub fn pow(self, exponent: Expr, name: &str) -> Expr {
181 let result = self.result.powf(exponent.result);
182 Expr::new_binary(self, exponent, Operation::Pow, result, name)
183 }
184
185 pub fn log(self, name: &str) -> Expr {
197 let result = self.result.ln();
198 Expr::new_unary(self, Operation::Log, result, name)
199 }
200
201 pub fn neg(self, name: &str) -> Expr {
213 let result = -self.result;
214 Expr::new_unary(self, Operation::Neg, result, name)
215 }
216
217 pub fn recalculate(&mut self) {
248 match self.expr_type() {
253 ExprType::Leaf => {}
254 ExprType::Unary => {
255 let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
256 operand1.recalculate();
257
258 self.result = match self.operation {
259 Operation::Tanh => operand1.result.tanh(),
260 Operation::Exp => operand1.result.exp(),
261 Operation::ReLU => operand1.result.max(0.0),
262 Operation::Log => operand1.result.ln(),
263 Operation::Neg => -operand1.result,
264 _ => panic!("Invalid unary operation {:?}", self.operation),
265 };
266 }
267 ExprType::Binary => {
268 let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
269 let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
270
271 operand1.recalculate();
272 operand2.recalculate();
273
274 self.result = match self.operation {
275 Operation::Add => operand1.result + operand2.result,
276 Operation::Sub => operand1.result - operand2.result,
277 Operation::Mul => operand1.result * operand2.result,
278 Operation::Div => operand1.result / operand2.result,
279 Operation::Pow => operand1.result.powf(operand2.result),
280 _ => panic!("Invalid binary operation: {:?}", self.operation),
281 };
282 }
283 }
284 }
285
286 pub fn learn(&mut self, learning_rate: f64) {
308 self.grad = 1.0;
309
310 let mut queue = VecDeque::from([self]);
311
312 while let Some(node) = queue.pop_front() {
313 match node.expr_type() {
314 ExprType::Leaf => {
315 node.learn_internal_leaf(learning_rate);
316 }
317 ExprType::Unary => {
318 let operand1 = node.operand1.as_mut().expect("Unary expression did not have an operand");
319 operand1.adjust_grad_unary(&node.operation, node.grad, node.result);
320 queue.push_back(operand1);
321 }
322 ExprType::Binary => {
323 let operand1 = node.operand1.as_mut().expect("Binary expression did not have an operand");
324 let operand2 = node.operand2.as_mut().expect("Binary expression did not have a second operand");
325
326 operand1.adjust_grad_binary_op1(&node.operation, node.grad, operand2);
327 operand2.adjust_grad_binary_op2(&node.operation, node.grad, operand1);
328
329 queue.push_back(operand1);
330 queue.push_back(operand2);
331 }
332 }
333 }
334 }
335
336 fn learn_internal_leaf(&mut self, learning_rate: f64) {
337 if self.is_learnable {
340 self.result -= learning_rate * self.grad;
341 }
342 }
343
344 fn adjust_grad_unary(&mut self, child_operation: &Operation, child_grad: f64, child_result: f64) {
345 match child_operation {
346 Operation::Tanh => {
347 let tanh_grad = 1.0 - (child_result * child_result);
348 self.grad = child_grad * tanh_grad;
349 }
350 Operation::Exp => {
351 self.grad = child_grad * child_result;
352 }
353 Operation::ReLU => {
354 self.grad = child_grad * if child_result > 0.0 { 1.0 } else { 0.0 };
355 }
356 Operation::Log => {
357 self.grad = child_grad / child_result;
358 }
359 Operation::Neg => {
360 self.grad = -child_grad;
361 }
362 _ => panic!("Invalid unary operation {:?}", child_operation),
363 }
364 }
365
366 fn adjust_grad_binary_op1(&mut self, child_operation: &Operation, child_grad: f64, operand2: &Expr) {
367 match child_operation {
368 Operation::Add => {
369 self.grad = child_grad;
370 }
371 Operation::Sub => {
372 self.grad = child_grad;
373 }
374 Operation::Mul => {
375 let operand2_result = operand2.result;
376
377 self.grad = child_grad * operand2_result;
378 }
379 Operation::Div => {
380 let operand2_result = operand2.result;
381
382 self.grad = child_grad / operand2_result;
383 }
384 Operation::Pow => {
385 let exponent = operand2.result;
386 let base = self.result;
387
388 self.grad = child_grad * exponent * base.powf(exponent - 1.0);
389 }
390 _ => panic!("Invalid binary operation: {:?}", child_operation),
391 }
392 }
393
394 fn adjust_grad_binary_op2(&mut self,child_operation: &Operation, child_grad: f64, operand1: &Expr) {
395 match child_operation {
396 Operation::Add => {
397 self.grad = child_grad;
398 }
399 Operation::Sub => {
400 self.grad = -child_grad;
401 }
402 Operation::Mul => {
403 let operand1_result = operand1.result;
404 self.grad = child_grad * operand1_result;
405 }
406 Operation::Div => {
407 let operand2_result = self.result;
408 let operand1_result = operand1.result;
409
410 self.grad = -child_grad * operand1_result / (operand2_result * operand2_result);
411 }
412 Operation::Pow => {
413 let exponent = self.result;
414 let base = operand1.result;
415
416 self.grad = child_grad * base.powf(exponent) * base.ln();
417 }
418 _ => panic!("Invalid binary operation: {:?}", child_operation),
419 }
420 }
421
422 pub fn find(&self, name: &str) -> Option<&Expr> {
438 let mut stack = vec![self];
439
440 while let Some(node) = stack.pop() {
441 if node.name == name {
442 return Some(node);
443 }
444
445 if let Some(operand1) = node.operand1.as_ref() {
446 stack.push(operand1);
447 }
448 if let Some(operand2) = node.operand2.as_ref() {
449 stack.push(operand2);
450 }
451 }
452
453 None
454 }
455
456 pub fn find_mut(&mut self, name: &str) -> Option<&mut Expr> {
474 let mut stack = vec![self];
475
476 while let Some(node) = stack.pop() {
477 if node.name == name {
478 return Some(node);
479 }
480
481 if let Some(operand1) = node.operand1.as_mut() {
482 stack.push(operand1);
483 }
484 if let Some(operand2) = node.operand2.as_mut() {
485 stack.push(operand2);
486 }
487 }
488
489 None
490 }
491
492 pub fn parameter_count(&self, learnable_only: bool) -> usize {
508 let mut stack = vec![self];
509 let mut count = 0;
510
511 while let Some(node) = stack.pop() {
512 if node.is_learnable || !learnable_only {
513 count += 1;
514 }
515
516 if let Some(operand1) = node.operand1.as_ref() {
517 stack.push(operand1);
518 }
519 if let Some(operand2) = node.operand2.as_ref() {
520 stack.push(operand2);
521 }
522 }
523
524 count
525 }
526}
527
528impl Add for Expr {
544 type Output = Expr;
545
546 fn add(self, other: Expr) -> Expr {
547 let result = self.result + other.result;
548 let name = &format!("({} + {})", self.name, other.name);
549 Expr::new_binary(self, other, Operation::Add, result, name)
550 }
551}
552
553impl Add<f64> for Expr {
567 type Output = Expr;
568
569 fn add(self, other: f64) -> Expr {
570 let operand2 = Expr::new_leaf(other, &other.to_string());
571 self + operand2
572 }
573}
574
575impl Add<Expr> for f64 {
589 type Output = Expr;
590
591 fn add(self, other: Expr) -> Expr {
592 let operand1 = Expr::new_leaf(self, &self.to_string());
593 operand1 + other
594 }
595}
596
597impl Mul for Expr {
614 type Output = Expr;
615
616 fn mul(self, other: Expr) -> Expr {
617 let result = self.result * other.result;
618 let name = &format!("({} * {})", self.name, other.name);
619 Expr::new_binary(self, other, Operation::Mul, result, name)
620 }
621}
622
623impl Mul<f64> for Expr {
638 type Output = Expr;
639
640 fn mul(self, other: f64) -> Expr {
641 let operand2 = Expr::new_leaf(other, &other.to_string());
642 self * operand2
643 }
644}
645
646impl Mul<Expr> for f64 {
661 type Output = Expr;
662
663 fn mul(self, other: Expr) -> Expr {
664 let operand1 = Expr::new_leaf(self, &self.to_string());
665 operand1 * other
666 }
667}
668
669impl Sub for Expr {
686 type Output = Expr;
687
688 fn sub(self, other: Expr) -> Expr {
689 let result = self.result - other.result;
690 let name = &format!("({} - {})", self.name, other.name);
691 Expr::new_binary(self, other, Operation::Sub, result, name)
692 }
693}
694
695impl Sub<f64> for Expr {
710 type Output = Expr;
711
712 fn sub(self, other: f64) -> Expr {
713 let operand2 = Expr::new_leaf(other, &other.to_string());
714 self - operand2
715 }
716}
717
718impl Sub<Expr> for f64 {
733 type Output = Expr;
734
735 fn sub(self, other: Expr) -> Expr {
736 let operand1 = Expr::new_leaf(self, &self.to_string());
737 operand1 - other
738 }
739}
740
741impl Div for Expr {
758 type Output = Expr;
759
760 fn div(self, other: Expr) -> Expr {
761 let result = self.result / other.result;
762 let name = &format!("({} / {})", self.name, other.name);
763 Expr::new_binary(self, other, Operation::Div, result, name)
764 }
765}
766
767impl Div<f64> for Expr {
782 type Output = Expr;
783
784 fn div(self, other: f64) -> Expr {
785 let operand2 = Expr::new_leaf(other, &other.to_string());
786 self / operand2
787 }
788}
789
790impl Sum for Expr {
811 fn sum<I>(iter: I) -> Self
812 where
813 I: Iterator<Item = Self>,
814 {
815 iter.reduce(|acc, x| acc + x)
816 .unwrap_or(Expr::new_leaf(0.0, "0.0"))
817 }
818}
819
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` differ by less than 1e-7.
    fn assert_float_eq(actual: f64, expected: f64) {
        let tolerance = 10.0_f64.powi(-7);
        assert!(
            (actual - expected).abs() < tolerance,
            "{} != {} (tol: {})",
            actual,
            expected,
            tolerance
        );
    }

    /// Checks a node built by a unary operation: its value, its operation,
    /// its operand's value, and that it has no second operand.
    fn assert_unary(node: Expr, result: f64, operation: Operation, operand: f64) {
        assert_eq!(node.result, result);
        assert_eq!(node.operation, operation);
        assert!(node.operand2.is_none());
        let inner = node.operand1.expect("unary node should have an operand");
        assert_eq!(inner.result, operand);
    }

    /// Checks a node built by a binary operator: its value, its operation,
    /// its generated name, and the values of both operands.
    fn assert_binary(node: Expr, result: f64, operation: Operation, name: &str, operands: (f64, f64)) {
        assert_eq!(node.result, result);
        assert_eq!(node.operation, operation);
        assert_eq!(node.name, name);
        let lhs = node.operand1.expect("binary node should have a first operand");
        let rhs = node.operand2.expect("binary node should have a second operand");
        assert_eq!((lhs.result, rhs.result), operands);
    }

    /// Runs `learn` on a binary node and returns its operands' gradients as
    /// `(left, right)`.
    fn binary_grads(mut node: Expr) -> (f64, f64) {
        node.learn(1e-09);
        let lhs = node.operand1.expect("binary node should have a first operand");
        let rhs = node.operand2.expect("binary node should have a second operand");
        (lhs.grad, rhs.grad)
    }

    /// Runs `learn` on a unary node and returns its operand's gradient.
    fn unary_grad(mut node: Expr) -> f64 {
        node.learn(1e-09);
        node.operand1.expect("unary node should have an operand").grad
    }

    /// Checks the operation and (approximate) gradient of a node after `learn`.
    fn assert_node(node: &Expr, operation: Operation, grad: f64) {
        assert_eq!(node.operation, operation);
        assert_float_eq(node.grad, grad);
    }

    #[test]
    fn test() {
        assert_eq!(Expr::new_leaf(1.0, "x").result, 1.0);
    }

    #[test]
    fn test_unary() {
        let leaf = Expr::new_leaf(1.0, "x");
        let node = Expr::new_unary(leaf, Operation::Tanh, 1.1, "tanh(x)");

        assert_eq!(node.result, 1.1);
        assert_eq!(node.operand1.unwrap().result, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let leaf = Expr::new_leaf(1.0, "x");
        let _ = Expr::new_unary(leaf, Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let node = Expr::new_binary(x, y, Operation::Add, 1.1, "x + y");

        assert_eq!(node.result, 1.1);
        assert_eq!(node.operand1.unwrap().result, 1.0);
        assert_eq!(node.operand2.unwrap().result, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let _ = Expr::new_binary(x, y, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let difference = Expr::new_binary(
            Expr::new_leaf(1.0, "x"),
            Expr::new_leaf(2.0, "y"),
            Operation::Sub,
            1.1,
            "x - y",
        );
        let root = Expr::new_unary(difference, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(root.result, 1.2);
        let difference = root.operand1.unwrap();
        assert_eq!(difference.result, 1.1);
        assert_eq!(difference.operand1.unwrap().result, 1.0);
        assert_eq!(difference.operand2.unwrap().result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let node = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        assert_unary(node, 0.7615941559557649, Operation::Tanh, 1.0);

        // tanh saturates towards +/-1 for large magnitudes and is 0 at 0.
        let tanh_of = |x: f64| Expr::new_leaf(x, "x").tanh("tanh(x)").result;
        assert_float_eq(tanh_of(10.74), 0.9999999);
        assert_float_eq(tanh_of(-10.74), -0.9999999);
        assert_float_eq(tanh_of(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let node = Expr::new_leaf(1.0, "x").exp("exp(x)");
        assert_unary(node, 2.718281828459045, Operation::Exp, 1.0);
    }

    #[test]
    fn test_relu() {
        assert_unary(Expr::new_leaf(-1.0, "x").relu("relu(x)"), 0.0, Operation::ReLU, -1.0);
        assert_unary(Expr::new_leaf(1.0, "x").relu("relu(x)"), 1.0, Operation::ReLU, 1.0);
    }

    #[test]
    fn test_pow() {
        let base = Expr::new_leaf(2.0, "x");
        let exponent = Expr::new_leaf(3.0, "y");
        let power = base.pow(exponent, "x^y");

        assert_eq!(power.result, 8.0);
        assert_eq!(power.operation, Operation::Pow);
        assert_eq!(power.operand1.unwrap().result, 2.0);
        assert_eq!(power.operand2.unwrap().result, 3.0);
    }

    #[test]
    fn test_add() {
        let sum = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        assert_binary(sum, 3.0, Operation::Add, "(x + y)", (1.0, 2.0));
    }

    #[test]
    fn test_add_f64() {
        let sum = Expr::new_leaf(1.0, "x") + 2.0;
        assert_binary(sum, 3.0, Operation::Add, "(x + 2)", (1.0, 2.0));
    }

    #[test]
    fn test_add_f64_expr() {
        let sum = 2.0 + Expr::new_leaf(1.0, "x");
        assert_binary(sum, 3.0, Operation::Add, "(2 + x)", (2.0, 1.0));
    }

    #[test]
    fn test_mul() {
        let product = Expr::new_leaf(2.0, "x") * Expr::new_leaf(3.0, "y");
        assert_binary(product, 6.0, Operation::Mul, "(x * y)", (2.0, 3.0));
    }

    #[test]
    fn test_mul_f64() {
        let product = Expr::new_leaf(2.0, "x") * 3.0;
        assert_binary(product, 6.0, Operation::Mul, "(x * 3)", (2.0, 3.0));
    }

    #[test]
    fn test_mul_f64_expr() {
        let product = 3.0 * Expr::new_leaf(2.0, "x");
        assert_binary(product, 6.0, Operation::Mul, "(3 * x)", (3.0, 2.0));
    }

    #[test]
    fn test_sub() {
        let difference = Expr::new_leaf(2.0, "x") - Expr::new_leaf(3.0, "y");
        assert_binary(difference, -1.0, Operation::Sub, "(x - y)", (2.0, 3.0));
    }

    #[test]
    fn test_sub_f64() {
        let difference = Expr::new_leaf(2.0, "x") - 3.0;
        assert_binary(difference, -1.0, Operation::Sub, "(x - 3)", (2.0, 3.0));
    }

    #[test]
    fn test_sub_f64_expr() {
        let difference = 3.0 - Expr::new_leaf(2.0, "x");
        assert_binary(difference, 1.0, Operation::Sub, "(3 - x)", (3.0, 2.0));
    }

    #[test]
    fn test_div() {
        let quotient = Expr::new_leaf(6.0, "x") / Expr::new_leaf(3.0, "y");
        assert_binary(quotient, 2.0, Operation::Div, "(x / y)", (6.0, 3.0));
    }

    #[test]
    fn test_div_f64() {
        let quotient = Expr::new_leaf(6.0, "x") / 3.0;
        assert_binary(quotient, 2.0, Operation::Div, "(x / 3)", (6.0, 3.0));
    }

    #[test]
    fn test_log() {
        let node = Expr::new_leaf(2.0, "x").log("log(x)");
        assert_unary(node, 0.6931471805599453, Operation::Log, 2.0);
    }

    #[test]
    fn test_neg() {
        let node = Expr::new_leaf(2.0, "x").neg("neg(x)");
        assert_unary(node, -2.0, Operation::Neg, 2.0);
    }

    #[test]
    fn test_backpropagation_add() {
        let sum = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        assert_eq!(binary_grads(sum), (1.0, 1.0));
    }

    #[test]
    fn test_backpropagation_sub() {
        let difference = Expr::new_leaf(1.0, "x") - Expr::new_leaf(2.0, "y");
        assert_eq!(binary_grads(difference), (1.0, -1.0));
    }

    #[test]
    fn test_backpropagation_mul() {
        let product = Expr::new_leaf(3.0, "x") * Expr::new_leaf(4.0, "y");
        assert_eq!(binary_grads(product), (4.0, 3.0));
    }

    #[test]
    fn test_backpropagation_div() {
        let quotient = Expr::new_leaf(3.0, "x") / Expr::new_leaf(4.0, "y");
        assert_eq!(binary_grads(quotient), (0.25, -0.1875));
    }

    #[test]
    fn test_backpropagation_tanh() {
        let node = Expr::new_leaf(0.0, "x").tanh("tanh(x)");
        assert_eq!(unary_grad(node), 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let node = Expr::new_leaf(-1.0, "x").relu("relu(x)");
        assert_eq!(unary_grad(node), 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let node = Expr::new_leaf(0.0, "x").exp("exp(x)");
        assert_eq!(unary_grad(node), 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let power = Expr::new_leaf(2.0, "x").pow(Expr::new_leaf(3.0, "y"), "x^y");
        assert_eq!(binary_grads(power), (12.0, 5.545177444479562));
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let sum = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        let mut root = sum.tanh("tanh(x + y)");
        root.learn(1e-09);

        let expected = 0.009866037165440211;
        let sum = root.operand1.unwrap();
        assert_eq!(sum.grad, expected);
        assert_eq!(sum.operand1.unwrap().grad, expected);
        assert_eq!(sum.operand2.unwrap().grad, expected);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.8813735870195432, "b");

        // n = ((x1 * w1) + (x2 * w2)) + b, o = tanh(n)
        let n = x1 * w1 + x2 * w2 + b;
        let mut o = n.tanh("tanh(n)");
        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_node(&n, Operation::Add, 0.5);

        let weighted_sum = n.operand1.unwrap();
        assert_node(&weighted_sum, Operation::Add, 0.5);
        assert_node(&n.operand2.unwrap(), Operation::None, 0.5);

        let x1w1 = weighted_sum.operand1.unwrap();
        let x2w2 = weighted_sum.operand2.unwrap();
        assert_node(&x1w1, Operation::Mul, 0.5);
        assert_node(&x2w2, Operation::Mul, 0.5);

        assert_node(&x1w1.operand1.unwrap(), Operation::None, -1.5);
        assert_node(&x1w1.operand2.unwrap(), Operation::None, 1.0);
        assert_node(&x2w2.operand1.unwrap(), Operation::None, 0.5);
        assert_node(&x2w2.operand2.unwrap(), Operation::None, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut leaf = Expr::new_leaf(1.0, "x");
        leaf.learn(1e-01);
        assert_float_eq(leaf.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut leaf = Expr::new_leaf(1.0, "x");
        leaf.is_learnable = false;
        leaf.learn(1e-01);
        assert_float_eq(leaf.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let root = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        let found = root.find("x").expect("x should be found");
        assert_eq!(found.name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let root = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        assert!(root.find("y").is_none());
    }

    #[test]
    fn test_sum_iterator() {
        let terms = vec![
            Expr::new_leaf(1.0, "x"),
            Expr::new_leaf(2.0, "y"),
            Expr::new_leaf(3.0, "z"),
        ];
        let total: Expr = terms.into_iter().sum();
        assert_eq!(total.result, 6.0);
    }

    #[test]
    fn test_find_after_clone() {
        let original = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        let copy = original.clone();
        let found = copy.find("x").expect("x should be found in the clone");
        assert_eq!(found.name, "x");
    }
}