1#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced an expression node's value.
///
/// The variants fall into three groups: `None` marks a leaf, `Tanh`, `Exp`
/// and `ReLU` take a single operand, and `Add`, `Sub`, `Mul`, `Div` and `Pow`
/// take two operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// No operation: the node is a leaf holding a plain value.
    None,
    /// Sum of two operands.
    Add,
    /// First operand minus the second.
    Sub,
    /// Product of two operands.
    Mul,
    /// First operand divided by the second.
    Div,
    /// Hyperbolic tangent of one operand.
    Tanh,
    /// Natural exponential of one operand.
    Exp,
    /// First operand raised to the power of the second.
    Pow,
    /// Rectified linear unit (`max(x, 0)`) of one operand.
    ReLU,
}
21
22impl Operation {
23    fn assert_is_type(&self, expr_type: ExprType) {
24        match self {
25            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
26            Operation::Tanh | Operation::Exp | Operation::ReLU => assert_eq!(expr_type, ExprType::Unary),
27            _ => assert_eq!(expr_type, ExprType::Binary),
28        }
29    }
30}
31
/// Arity class of an expression node: how many operands its operation takes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// A node without operands.
    Leaf,
    /// A node with exactly one operand.
    Unary,
    /// A node with exactly two operands.
    Binary,
}
38
39#[derive(Debug, Clone)]
49pub struct Expr {
50    operand1: Option<Box<Expr>>,
51    operand2: Option<Box<Expr>>,
52    operation: Operation,
53    pub result: f64,
55    pub is_learnable: bool,
57    grad: f64,
58    pub name: String,
60}
61
62impl Expr {
63    pub fn new_leaf(value: f64, name: &str) -> Expr {
72        Expr {
73            operand1: None,
74            operand2: None,
75            operation: Operation::None,
76            result: value,
77            is_learnable: true,
78            grad: 0.0,
79            name: name.to_string(),
80        }
81    }
82
83    fn expr_type(&self) -> ExprType {
84        match self.operation {
85            Operation::None => ExprType::Leaf,
86            Operation::Tanh | Operation::Exp | Operation::ReLU => ExprType::Unary,
87            _ => ExprType::Binary,
88        }
89    }
90
91    fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
92        operation.assert_is_type(ExprType::Unary);
93        Expr {
94            operand1: Some(Box::new(operand)),
95            operand2: None,
96            operation,
97            result,
98            is_learnable: true,
99            grad: 0.0,
100            name: name.to_string(),
101        }
102    }
103
104    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
105        operation.assert_is_type(ExprType::Binary);
106        Expr {
107            operand1: Some(Box::new(operand1)),
108            operand2: Some(Box::new(operand2)),
109            operation,
110            result,
111            is_learnable: true,
112            grad: 0.0,
113            name: name.to_string(),
114        }
115    }
116
117    pub fn tanh(self, name: &str) -> Expr {
129        let result = self.result.tanh();
130        Expr::new_unary(self, Operation::Tanh, result, name)
131    }
132
133    pub fn relu(self, name: &str) -> Expr {
145        let result = self.result.max(0.0);
146        Expr::new_unary(self, Operation::ReLU, result, name)
147    }
148
149    pub fn exp(self, name: &str) -> Expr {
161        let result = self.result.exp();
162        Expr::new_unary(self, Operation::Exp, result, name)
163    }
164
165    pub fn pow(self, exponent: Expr, name: &str) -> Expr {
178        let result = self.result.powf(exponent.result);
179        Expr::new_binary(self, exponent, Operation::Pow, result, name)
180    }
181
182    pub fn recalculate(&mut self) {
197        match self.expr_type() {
198            ExprType::Leaf => {}
199            ExprType::Unary => {
200                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
201                operand1.recalculate();
202
203                self.result = match self.operation {
204                    Operation::Tanh => operand1.result.tanh(),
205                    Operation::Exp => operand1.result.exp(),
206                    Operation::ReLU => operand1.result.max(0.0),
207                    _ => panic!("Invalid unary operation {:?}", self.operation),
208                };
209            }
210            ExprType::Binary => {
211                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
212                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
213
214                operand1.recalculate();
215                operand2.recalculate();
216
217                self.result = match self.operation {
218                    Operation::Add => operand1.result + operand2.result,
219                    Operation::Sub => operand1.result - operand2.result,
220                    Operation::Mul => operand1.result * operand2.result,
221                    Operation::Div => operand1.result / operand2.result,
222                    Operation::Pow => operand1.result.powf(operand2.result),
223                    _ => panic!("Invalid binary operation: {:?}", self.operation),
224                };
225            }
226        }
227    }
228
229    pub fn learn(&mut self, learning_rate: f64) {
251        self.grad = 1.0;
252        self.learn_internal(learning_rate);
253    }
254
255    fn learn_internal(&mut self, learning_rate: f64) {
256        match self.expr_type() {
257            ExprType::Leaf => {
258                if self.is_learnable {
261                    self.result -= learning_rate * self.grad;
262                 }
263            }
264            ExprType::Unary => {
265                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
266
267                match self.operation {
268                    Operation::Tanh => {
269                        let tanh_grad = 1.0 - (self.result * self.result);
270                        operand1.grad = self.grad * tanh_grad;
271                    }
272                    Operation::Exp => {
273                        operand1.grad = self.grad * self.result;
274                    }
275                    Operation::ReLU => {
276                        operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
277                    }
278                    _ => panic!("Invalid unary operation {:?}", self.operation),
279                }
280
281                operand1.learn_internal(learning_rate);
282            }
283            ExprType::Binary => {
284                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
285                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
286
287                match self.operation {
288                    Operation::Add => {
289                        operand1.grad = self.grad;
290                        operand2.grad = self.grad;
291                    }
292                    Operation::Sub => {
293                        operand1.grad = self.grad;
294                        operand2.grad = -self.grad;
295                    }
296                    Operation::Mul => {
297                        let operand2_result = operand2.result;
298                        let operand1_result = operand1.result;
299
300                        operand1.grad = self.grad * operand2_result;
301                        operand2.grad = self.grad * operand1_result;
302                    }
303                    Operation::Div => {
304                        let operand2_result = operand2.result;
305                        let operand1_result = operand1.result;
306
307                        operand1.grad = self.grad / operand2_result;
308                        operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
309                    }
310                    Operation::Pow => {
311                        let exponent = operand2.result;
312                        let base = operand1.result;
313
314                        operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
315                        operand2.grad = self.grad * base.powf(exponent) * base.ln();
316                    }
317                    _ => panic!("Invalid binary operation: {:?}", self.operation),
318                }
319
320                operand1.learn_internal(learning_rate);
321                operand2.learn_internal(learning_rate);
322            }
323        }
324    }
325
326    pub fn find(&self, name: &str) -> Option<&Expr> {
342        if self.name == name {
343            return Some(self);
344        }
345
346        match self.expr_type() {
347            ExprType::Leaf => None,
348            ExprType::Unary => {
349                let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
350                operand1.find(name)
351            }
352            ExprType::Binary => {
353                let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
354                let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
355
356                let result = operand1.find(name);
357                if result.is_some() {
358                    return result;
359                }
360
361                operand2.find(name)
362            }
363        }
364    }
365}
366
367impl Add for Expr {
383    type Output = Expr;
384
385    fn add(self, other: Expr) -> Expr {
386        let result = self.result + other.result;
387        let name = &format!("({} + {})", self.name, other.name);
388        Expr::new_binary(self, other, Operation::Add, result, name)
389    }
390}
391
392impl Add<f64> for Expr {
406    type Output = Expr;
407
408    fn add(self, other: f64) -> Expr {
409        let operand2 = Expr::new_leaf(other, &other.to_string());
410        self + operand2
411    }
412}
413
414impl Add<Expr> for f64 {
428    type Output = Expr;
429
430    fn add(self, other: Expr) -> Expr {
431        let operand1 = Expr::new_leaf(self, &self.to_string());
432        operand1 + other
433    }
434}
435
436impl Mul for Expr {
453    type Output = Expr;
454
455    fn mul(self, other: Expr) -> Expr {
456        let result = self.result * other.result;
457        let name = &format!("({} * {})", self.name, other.name);
458        Expr::new_binary(self, other, Operation::Mul, result, name)
459    }
460}
461
462impl Mul<f64> for Expr {
477    type Output = Expr;
478
479    fn mul(self, other: f64) -> Expr {
480        let operand2 = Expr::new_leaf(other, &other.to_string());
481        self * operand2
482    }
483}
484
485impl Mul<Expr> for f64 {
500    type Output = Expr;
501
502    fn mul(self, other: Expr) -> Expr {
503        let operand1 = Expr::new_leaf(self, &self.to_string());
504        operand1 * other
505    }
506}
507
508impl Sub for Expr {
525    type Output = Expr;
526
527    fn sub(self, other: Expr) -> Expr {
528        let result = self.result - other.result;
529        let name = &format!("({} - {})", self.name, other.name);
530        Expr::new_binary(self, other, Operation::Sub, result, name)
531    }
532}
533
534impl Sub<f64> for Expr {
549    type Output = Expr;
550
551    fn sub(self, other: f64) -> Expr {
552        let operand2 = Expr::new_leaf(other, &other.to_string());
553        self - operand2
554    }
555}
556
557impl Sub<Expr> for f64 {
572    type Output = Expr;
573
574    fn sub(self, other: Expr) -> Expr {
575        let operand1 = Expr::new_leaf(self, &self.to_string());
576        operand1 - other
577    }
578}
579
580impl Div for Expr {
597    type Output = Expr;
598
599    fn div(self, other: Expr) -> Expr {
600        let result = self.result / other.result;
601        let name = &format!("({} / {})", self.name, other.name);
602        Expr::new_binary(self, other, Operation::Div, result, name)
603    }
604}
605
606impl Div<f64> for Expr {
621    type Output = Expr;
622
623    fn div(self, other: f64) -> Expr {
624        let operand2 = Expr::new_leaf(other, &other.to_string());
625        self / operand2
626    }
627}
628
629impl Sum for Expr {
650    fn sum<I>(iter: I) -> Self
651    where
652        I: Iterator<Item = Self>,
653    {
654        iter.reduce(|acc, x| acc + x)
655            .unwrap_or(Expr::new_leaf(0.0, "0.0"))
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to seven decimal places.
    ///
    /// Used wherever the expected value comes out of `tanh`, `exp`, `powf` or
    /// `ln`: the standard library does not promise bit-identical results for
    /// those functions across platforms, so exact equality would be brittle.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    /// Borrows the first operand of `node`, failing the test if it is absent.
    fn first_operand(node: &Expr) -> &Expr {
        node.operand1
            .as_deref()
            .expect("node should have a first operand")
    }

    /// Borrows the second operand of `node`, failing the test if it is absent.
    fn second_operand(node: &Expr) -> &Expr {
        node.operand2
            .as_deref()
            .expect("node should have a second operand")
    }

    /// Checks the shape of a one-operand node: its operation, the value of its
    /// operand, and the absence of a second operand.
    fn assert_unary_node(node: &Expr, operation: Operation, operand: f64) {
        assert_eq!(node.operation, operation);
        assert_eq!(first_operand(node).result, operand);
        assert!(node.operand2.is_none());
    }

    /// Checks the shape of a two-operand node: its operation and the values of
    /// both operands.
    fn assert_binary_node(node: &Expr, operation: Operation, lhs: f64, rhs: f64) {
        assert_eq!(node.operation, operation);
        assert_eq!(first_operand(node).result, lhs);
        assert_eq!(second_operand(node).result, rhs);
    }

    /// Returns the gradients stored on both operands of a two-operand node.
    fn operand_grads(node: &Expr) -> (f64, f64) {
        (first_operand(node).grad, second_operand(node).grad)
    }

    #[test]
    fn test() {
        let expr = Expr::new_leaf(1.0, "x");
        assert_eq!(expr.result, 1.0);
    }

    #[test]
    fn test_unary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");

        assert_eq!(expr2.result, 1.1);
        assert_eq!(first_operand(&expr2).result, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");

        assert_eq!(expr3.result, 1.1);
        assert_eq!(first_operand(&expr3).result, 1.0);
        assert_eq!(second_operand(&expr3).result, 2.0);
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(expr4.result, 1.2);
        let expr3 = first_operand(&expr4);
        assert_eq!(expr3.result, 1.1);
        assert_eq!(first_operand(expr3).result, 1.0);
        assert_eq!(second_operand(expr3).result, 2.0);
    }

    #[test]
    fn test_tanh() {
        let expr2 = Expr::new_leaf(1.0, "x").tanh("tanh(x)");

        assert_float_eq(expr2.result, 0.7615941559557649);
        assert_unary_node(&expr2, Operation::Tanh, 1.0);

        fn get_tanh(x: f64) -> f64 {
            Expr::new_leaf(x, "x").tanh("tanh(x)").result
        }

        assert_float_eq(get_tanh(10.74), 0.9999999);
        assert_float_eq(get_tanh(-10.74), -0.9999999);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        let expr2 = Expr::new_leaf(1.0, "x").exp("exp(x)");

        assert_float_eq(expr2.result, std::f64::consts::E);
        assert_unary_node(&expr2, Operation::Exp, 1.0);
    }

    #[test]
    fn test_relu() {
        // Negative input is clamped to zero.
        let expr2 = Expr::new_leaf(-1.0, "x").relu("relu(x)");
        assert_eq!(expr2.result, 0.0);
        assert_unary_node(&expr2, Operation::ReLU, -1.0);

        // Positive input passes through unchanged.
        let expr2 = Expr::new_leaf(1.0, "x").relu("relu(x)");
        assert_eq!(expr2.result, 1.0);
        assert_unary_node(&expr2, Operation::ReLU, 1.0);
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0, "x");
        let expr2 = Expr::new_leaf(3.0, "y");
        let result = expr.pow(expr2, "x^y");

        assert_float_eq(result.result, 8.0);
        assert_binary_node(&result, Operation::Pow, 2.0, 3.0);
    }

    #[test]
    fn test_add() {
        let expr3 = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");

        assert_eq!(expr3.result, 3.0);
        assert_binary_node(&expr3, Operation::Add, 1.0, 2.0);
        assert_eq!(expr3.name, "(x + y)");
    }

    #[test]
    fn test_add_f64() {
        let expr2 = Expr::new_leaf(1.0, "x") + 2.0;

        assert_eq!(expr2.result, 3.0);
        assert_binary_node(&expr2, Operation::Add, 1.0, 2.0);
        assert_eq!(expr2.name, "(x + 2)");
    }

    #[test]
    fn test_add_f64_expr() {
        let expr2 = 2.0 + Expr::new_leaf(1.0, "x");

        assert_eq!(expr2.result, 3.0);
        assert_binary_node(&expr2, Operation::Add, 2.0, 1.0);
        assert_eq!(expr2.name, "(2 + x)");
    }

    #[test]
    fn test_mul() {
        let expr3 = Expr::new_leaf(2.0, "x") * Expr::new_leaf(3.0, "y");

        assert_eq!(expr3.result, 6.0);
        assert_binary_node(&expr3, Operation::Mul, 2.0, 3.0);
        assert_eq!(expr3.name, "(x * y)");
    }

    #[test]
    fn test_mul_f64() {
        let expr2 = Expr::new_leaf(2.0, "x") * 3.0;

        assert_eq!(expr2.result, 6.0);
        assert_binary_node(&expr2, Operation::Mul, 2.0, 3.0);
        assert_eq!(expr2.name, "(x * 3)");
    }

    #[test]
    fn test_mul_f64_expr() {
        let expr2 = 3.0 * Expr::new_leaf(2.0, "x");

        assert_eq!(expr2.result, 6.0);
        assert_binary_node(&expr2, Operation::Mul, 3.0, 2.0);
        assert_eq!(expr2.name, "(3 * x)");
    }

    #[test]
    fn test_sub() {
        let expr3 = Expr::new_leaf(2.0, "x") - Expr::new_leaf(3.0, "y");

        assert_eq!(expr3.result, -1.0);
        assert_binary_node(&expr3, Operation::Sub, 2.0, 3.0);
        assert_eq!(expr3.name, "(x - y)");
    }

    #[test]
    fn test_sub_f64() {
        let expr2 = Expr::new_leaf(2.0, "x") - 3.0;

        assert_eq!(expr2.result, -1.0);
        assert_binary_node(&expr2, Operation::Sub, 2.0, 3.0);
        assert_eq!(expr2.name, "(x - 3)");
    }

    #[test]
    fn test_sub_f64_expr() {
        let expr2 = 3.0 - Expr::new_leaf(2.0, "x");

        assert_eq!(expr2.result, 1.0);
        assert_binary_node(&expr2, Operation::Sub, 3.0, 2.0);
        assert_eq!(expr2.name, "(3 - x)");
    }

    #[test]
    fn test_div() {
        let expr3 = Expr::new_leaf(6.0, "x") / Expr::new_leaf(3.0, "y");

        assert_eq!(expr3.result, 2.0);
        assert_binary_node(&expr3, Operation::Div, 6.0, 3.0);
        assert_eq!(expr3.name, "(x / y)");
    }

    #[test]
    fn test_div_f64() {
        let expr2 = Expr::new_leaf(6.0, "x") / 3.0;

        assert_eq!(expr2.result, 2.0);
        assert_binary_node(&expr2, Operation::Div, 6.0, 3.0);
        assert_eq!(expr2.name, "(x / 3)");
    }

    #[test]
    fn test_recalculate() {
        let mut expr = (Expr::new_leaf(2.0, "x") * Expr::new_leaf(3.0, "y")).relu("relu(x * y)");
        assert_eq!(expr.result, 6.0);

        // Change the leaf `x` in place; cached results above it are now stale.
        expr.operand1
            .as_mut()
            .unwrap()
            .operand1
            .as_mut()
            .unwrap()
            .result = -2.0;
        expr.recalculate();

        assert_eq!(first_operand(&expr).result, -6.0);
        assert_eq!(expr.result, 0.0);
    }

    #[test]
    fn test_backpropagation_add() {
        let mut expr3 = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");

        expr3.learn(1e-09);

        assert_eq!(operand_grads(&expr3), (1.0, 1.0));
    }

    #[test]
    fn test_backpropagation_sub() {
        let mut expr3 = Expr::new_leaf(1.0, "x") - Expr::new_leaf(2.0, "y");

        expr3.learn(1e-09);

        assert_eq!(operand_grads(&expr3), (1.0, -1.0));
    }

    #[test]
    fn test_backpropagation_mul() {
        let mut expr3 = Expr::new_leaf(3.0, "x") * Expr::new_leaf(4.0, "y");

        expr3.learn(1e-09);

        assert_eq!(operand_grads(&expr3), (4.0, 3.0));
    }

    #[test]
    fn test_backpropagation_div() {
        let mut expr3 = Expr::new_leaf(3.0, "x") / Expr::new_leaf(4.0, "y");

        expr3.learn(1e-09);

        assert_eq!(operand_grads(&expr3), (0.25, -0.1875));
    }

    #[test]
    fn test_backpropagation_tanh() {
        let mut expr2 = Expr::new_leaf(0.0, "x").tanh("tanh(x)");

        expr2.learn(1e-09);

        assert_float_eq(first_operand(&expr2).grad, 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let mut expr2 = Expr::new_leaf(-1.0, "x").relu("relu(x)");

        expr2.learn(1e-09);

        assert_eq!(first_operand(&expr2).grad, 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let mut expr2 = Expr::new_leaf(0.0, "x").exp("exp(x)");

        expr2.learn(1e-09);

        assert_float_eq(first_operand(&expr2).grad, 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let operand1 = Expr::new_leaf(2.0, "x");
        let operand2 = Expr::new_leaf(3.0, "y");
        let mut expr3 = operand1.pow(operand2, "x^y");

        expr3.learn(1e-09);

        let (base_grad, exponent_grad) = operand_grads(&expr3);
        assert_float_eq(base_grad, 12.0);
        assert_float_eq(exponent_grad, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let expr3 = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        let mut expr4 = expr3.tanh("tanh(x + y)");

        expr4.learn(1e-09);

        let expr3 = first_operand(&expr4);
        let (grad1, grad2) = operand_grads(expr3);

        assert_float_eq(expr3.grad, 0.009866037165440211);
        assert_float_eq(grad1, 0.009866037165440211);
        assert_float_eq(grad2, 0.009866037165440211);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.881373587019543, "b");

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh("tanh(n)");

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = first_operand(&o);
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = first_operand(n);
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = second_operand(n);
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = first_operand(x1w1_x2w2);
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = second_operand(x1w1_x2w2);
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        let x1 = first_operand(x1w1);
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        let w1 = second_operand(x1w1);
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        let x2 = first_operand(x2w2);
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        let w2 = second_operand(x2w2);
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr2 = Expr::new_leaf(1.0, "x").tanh("tanh(x)");

        let found = expr2.find("x").expect("leaf `x` should be found");
        assert_eq!(found.name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let expr2 = Expr::new_leaf(1.0, "x").tanh("tanh(x)");

        assert!(expr2.find("y").is_none());
    }

    #[test]
    fn test_sum_iterator() {
        let expr = Expr::new_leaf(1.0, "x");
        let expr2 = Expr::new_leaf(2.0, "y");
        let expr3 = Expr::new_leaf(3.0, "z");

        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    #[test]
    fn test_sum_empty_iterator() {
        let sum: Expr = Vec::<Expr>::new().into_iter().sum::<Expr>();
        assert_eq!(sum.result, 0.0);
    }

    #[test]
    fn test_find_after_clone() {
        let expr2 = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        let expr2_clone = expr2.clone();

        let found = expr2_clone
            .find("x")
            .expect("leaf `x` should be found in the clone");
        assert_eq!(found.name, "x");
    }
}