1#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced an expression node from its operand(s).
#[derive(Clone, Debug, PartialEq)]
enum Operation {
    /// No operation: the node is a leaf holding a raw value.
    None,

    // Binary operations (two operands).
    /// `a + b`
    Add,
    /// `a - b`
    Sub,
    /// `a * b`
    Mul,
    /// `a / b`
    Div,
    /// `a` raised to the power `b` (`f64::powf`).
    Pow,

    // Unary operations (one operand).
    /// `tanh(a)`
    Tanh,
    /// `e^a`
    Exp,
    /// `max(a, 0)`
    ReLU,
    /// Natural logarithm `ln(a)`.
    Log,
    /// `-a`
    Neg,
}
23
24impl Operation {
25    fn assert_is_type(&self, expr_type: ExprType) {
26        match self {
27            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
28            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
29            _ => assert_eq!(expr_type, ExprType::Binary),
30        }
31    }
32}
33
/// Shape of an expression node, classified by how many operands it has.
#[derive(PartialEq, Debug)]
enum ExprType {
    /// No operands.
    Leaf,
    /// Exactly one operand (`operand1`).
    Unary,
    /// Two operands (`operand1` and `operand2`).
    Binary,
}
40
41#[derive(Debug, Clone)]
51pub struct Expr {
52    operand1: Option<Box<Expr>>,
53    operand2: Option<Box<Expr>>,
54    operation: Operation,
55    pub result: f64,
57    pub is_learnable: bool,
59    grad: f64,
60    pub name: String,
62}
63
64impl Expr {
65    pub fn new_leaf(value: f64, name: &str) -> Expr {
74        Expr {
75            operand1: None,
76            operand2: None,
77            operation: Operation::None,
78            result: value,
79            is_learnable: true,
80            grad: 0.0,
81            name: name.to_string(),
82        }
83    }
84
85    fn expr_type(&self) -> ExprType {
86        match self.operation {
87            Operation::None => ExprType::Leaf,
88            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
89            _ => ExprType::Binary,
90        }
91    }
92
93    fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
94        operation.assert_is_type(ExprType::Unary);
95        Expr {
96            operand1: Some(Box::new(operand)),
97            operand2: None,
98            operation,
99            result,
100            is_learnable: false,
101            grad: 0.0,
102            name: name.to_string(),
103        }
104    }
105
106    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
107        operation.assert_is_type(ExprType::Binary);
108        Expr {
109            operand1: Some(Box::new(operand1)),
110            operand2: Some(Box::new(operand2)),
111            operation,
112            result,
113            is_learnable: false,
114            grad: 0.0,
115            name: name.to_string(),
116        }
117    }
118
119    pub fn tanh(self, name: &str) -> Expr {
131        let result = self.result.tanh();
132        Expr::new_unary(self, Operation::Tanh, result, name)
133    }
134
135    pub fn relu(self, name: &str) -> Expr {
147        let result = self.result.max(0.0);
148        Expr::new_unary(self, Operation::ReLU, result, name)
149    }
150
151    pub fn exp(self, name: &str) -> Expr {
163        let result = self.result.exp();
164        Expr::new_unary(self, Operation::Exp, result, name)
165    }
166
167    pub fn pow(self, exponent: Expr, name: &str) -> Expr {
180        let result = self.result.powf(exponent.result);
181        Expr::new_binary(self, exponent, Operation::Pow, result, name)
182    }
183
184    pub fn log(self, name: &str) -> Expr {
196        let result = self.result.ln();
197        Expr::new_unary(self, Operation::Log, result, name)
198    }
199
200    pub fn neg(self, name: &str) -> Expr {
212        let result = -self.result;
213        Expr::new_unary(self, Operation::Neg, result, name)
214    }
215
216    pub fn recalculate(&mut self) {
231        match self.expr_type() {
232            ExprType::Leaf => {}
233            ExprType::Unary => {
234                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
235                operand1.recalculate();
236
237                self.result = match self.operation {
238                    Operation::Tanh => operand1.result.tanh(),
239                    Operation::Exp => operand1.result.exp(),
240                    Operation::ReLU => operand1.result.max(0.0),
241                    Operation::Log => operand1.result.ln(),
242                    Operation::Neg => -operand1.result,
243                    _ => panic!("Invalid unary operation {:?}", self.operation),
244                };
245            }
246            ExprType::Binary => {
247                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
248                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
249
250                operand1.recalculate();
251                operand2.recalculate();
252
253                self.result = match self.operation {
254                    Operation::Add => operand1.result + operand2.result,
255                    Operation::Sub => operand1.result - operand2.result,
256                    Operation::Mul => operand1.result * operand2.result,
257                    Operation::Div => operand1.result / operand2.result,
258                    Operation::Pow => operand1.result.powf(operand2.result),
259                    _ => panic!("Invalid binary operation: {:?}", self.operation),
260                };
261            }
262        }
263    }
264
265    pub fn learn(&mut self, learning_rate: f64) {
287        self.grad = 1.0;
288        self.learn_internal(learning_rate);
289    }
290
291    fn learn_internal(&mut self, learning_rate: f64) {
292        match self.expr_type() {
293            ExprType::Leaf => {
294                if self.is_learnable {
297                    self.result -= learning_rate * self.grad;
298                 }
299            }
300            ExprType::Unary => {
301                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
302
303                match self.operation {
304                    Operation::Tanh => {
305                        let tanh_grad = 1.0 - (self.result * self.result);
306                        operand1.grad = self.grad * tanh_grad;
307                    }
308                    Operation::Exp => {
309                        operand1.grad = self.grad * self.result;
310                    }
311                    Operation::ReLU => {
312                        operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
313                    }
314                    Operation::Log => {
315                        operand1.grad = self.grad / operand1.result;
316                    }
317                    Operation::Neg => {
318                        operand1.grad = -self.grad;
319                    }
320                    _ => panic!("Invalid unary operation {:?}", self.operation),
321                }
322
323                operand1.learn_internal(learning_rate);
324            }
325            ExprType::Binary => {
326                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
327                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
328
329                match self.operation {
330                    Operation::Add => {
331                        operand1.grad = self.grad;
332                        operand2.grad = self.grad;
333                    }
334                    Operation::Sub => {
335                        operand1.grad = self.grad;
336                        operand2.grad = -self.grad;
337                    }
338                    Operation::Mul => {
339                        let operand2_result = operand2.result;
340                        let operand1_result = operand1.result;
341
342                        operand1.grad = self.grad * operand2_result;
343                        operand2.grad = self.grad * operand1_result;
344                    }
345                    Operation::Div => {
346                        let operand2_result = operand2.result;
347                        let operand1_result = operand1.result;
348
349                        operand1.grad = self.grad / operand2_result;
350                        operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
351                    }
352                    Operation::Pow => {
353                        let exponent = operand2.result;
354                        let base = operand1.result;
355
356                        operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
357                        operand2.grad = self.grad * base.powf(exponent) * base.ln();
358                    }
359                    _ => panic!("Invalid binary operation: {:?}", self.operation),
360                }
361
362                operand1.learn_internal(learning_rate);
363                operand2.learn_internal(learning_rate);
364            }
365        }
366    }
367
368    pub fn find(&self, name: &str) -> Option<&Expr> {
384        if self.name == name {
385            return Some(self);
386        }
387
388        match self.expr_type() {
389            ExprType::Leaf => None,
390            ExprType::Unary => {
391                let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
392                operand1.find(name)
393            }
394            ExprType::Binary => {
395                let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
396                let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
397
398                let result = operand1.find(name);
399                if result.is_some() {
400                    return result;
401                }
402
403                operand2.find(name)
404            }
405        }
406    }
407}
408
409impl Add for Expr {
425    type Output = Expr;
426
427    fn add(self, other: Expr) -> Expr {
428        let result = self.result + other.result;
429        let name = &format!("({} + {})", self.name, other.name);
430        Expr::new_binary(self, other, Operation::Add, result, name)
431    }
432}
433
434impl Add<f64> for Expr {
448    type Output = Expr;
449
450    fn add(self, other: f64) -> Expr {
451        let operand2 = Expr::new_leaf(other, &other.to_string());
452        self + operand2
453    }
454}
455
456impl Add<Expr> for f64 {
470    type Output = Expr;
471
472    fn add(self, other: Expr) -> Expr {
473        let operand1 = Expr::new_leaf(self, &self.to_string());
474        operand1 + other
475    }
476}
477
478impl Mul for Expr {
495    type Output = Expr;
496
497    fn mul(self, other: Expr) -> Expr {
498        let result = self.result * other.result;
499        let name = &format!("({} * {})", self.name, other.name);
500        Expr::new_binary(self, other, Operation::Mul, result, name)
501    }
502}
503
504impl Mul<f64> for Expr {
519    type Output = Expr;
520
521    fn mul(self, other: f64) -> Expr {
522        let operand2 = Expr::new_leaf(other, &other.to_string());
523        self * operand2
524    }
525}
526
527impl Mul<Expr> for f64 {
542    type Output = Expr;
543
544    fn mul(self, other: Expr) -> Expr {
545        let operand1 = Expr::new_leaf(self, &self.to_string());
546        operand1 * other
547    }
548}
549
550impl Sub for Expr {
567    type Output = Expr;
568
569    fn sub(self, other: Expr) -> Expr {
570        let result = self.result - other.result;
571        let name = &format!("({} - {})", self.name, other.name);
572        Expr::new_binary(self, other, Operation::Sub, result, name)
573    }
574}
575
576impl Sub<f64> for Expr {
591    type Output = Expr;
592
593    fn sub(self, other: f64) -> Expr {
594        let operand2 = Expr::new_leaf(other, &other.to_string());
595        self - operand2
596    }
597}
598
599impl Sub<Expr> for f64 {
614    type Output = Expr;
615
616    fn sub(self, other: Expr) -> Expr {
617        let operand1 = Expr::new_leaf(self, &self.to_string());
618        operand1 - other
619    }
620}
621
622impl Div for Expr {
639    type Output = Expr;
640
641    fn div(self, other: Expr) -> Expr {
642        let result = self.result / other.result;
643        let name = &format!("({} / {})", self.name, other.name);
644        Expr::new_binary(self, other, Operation::Div, result, name)
645    }
646}
647
648impl Div<f64> for Expr {
663    type Output = Expr;
664
665    fn div(self, other: f64) -> Expr {
666        let operand2 = Expr::new_leaf(other, &other.to_string());
667        self / operand2
668    }
669}
670
671impl Sum for Expr {
692    fn sum<I>(iter: I) -> Self
693    where
694        I: Iterator<Item = Self>,
695    {
696        iter.reduce(|acc, x| acc + x)
697            .unwrap_or(Expr::new_leaf(0.0, "0.0"))
698    }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `f1` and `f2` differ by less than 1e-7.
    ///
    /// Values produced by transcendental functions (`tanh`, `exp`, `ln`, and
    /// gradients derived from them) come from the platform's libm, which is
    /// not required to round correctly. They are compared with this helper
    /// rather than `assert_eq!` against a hard-coded literal.
    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    /// Checks a unary node: its operation, its result (within tolerance), the
    /// value of its single operand, and that it has no second operand.
    fn assert_unary(expr: Expr, operation: Operation, result: f64, operand: f64) {
        assert_eq!(expr.operation, operation);
        assert_float_eq(expr.result, result);
        assert!(expr.operand2.is_none());
        let operand1 = expr.operand1.expect("unary node should have an operand");
        assert_eq!(operand1.result, operand);
    }

    /// Checks a binary node: its operation, exact result, name, and the values
    /// of both operands.
    fn assert_binary(expr: Expr, operation: Operation, result: f64, lhs: f64, rhs: f64, name: &str) {
        assert_eq!(expr.operation, operation);
        assert_eq!(expr.result, result);
        assert_eq!(expr.name, name);
        let operand1 = expr.operand1.expect("binary node should have a first operand");
        let operand2 = expr.operand2.expect("binary node should have a second operand");
        assert_eq!(operand1.result, lhs);
        assert_eq!(operand2.result, rhs);
    }

    /// Runs `learn` with a negligible learning rate on a unary node and
    /// returns the gradient stored on its operand.
    fn learned_operand_grad(mut expr: Expr) -> f64 {
        expr.learn(1e-09);
        expr.operand1.expect("unary node should have an operand").grad
    }

    /// Runs `learn` with a negligible learning rate on a binary node and
    /// returns the gradients stored on its two operands.
    fn learned_operand_grads(mut expr: Expr) -> (f64, f64) {
        expr.learn(1e-09);
        let operand1 = expr.operand1.expect("binary node should have a first operand");
        let operand2 = expr.operand2.expect("binary node should have a second operand");
        (operand1.grad, operand2.grad)
    }

    #[test]
    fn test() {
        let expr = Expr::new_leaf(1.0, "x");
        assert_eq!(expr.result, 1.0);
    }

    #[test]
    fn test_unary() {
        let expr = Expr::new_unary(Expr::new_leaf(1.0, "x"), Operation::Tanh, 1.1, "tanh(x)");
        assert_unary(expr, Operation::Tanh, 1.1, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_unary_expression_type_check() {
        let _expr = Expr::new_unary(Expr::new_leaf(1.0, "x"), Operation::Add, 1.1, "tanh(x)");
    }

    #[test]
    fn test_binary() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let expr = Expr::new_binary(x, y, Operation::Add, 1.1, "x + y");
        assert_binary(expr, Operation::Add, 1.1, 1.0, 2.0, "x + y");
    }

    #[test]
    #[should_panic]
    fn test_binary_expression_type_check() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let _expr = Expr::new_binary(x, y, Operation::Tanh, 3.0, "x + y");
    }

    #[test]
    fn test_mixed_tree() {
        let x = Expr::new_leaf(1.0, "x");
        let y = Expr::new_leaf(2.0, "y");
        let difference = Expr::new_binary(x, y, Operation::Sub, 1.1, "x - y");
        let expr = Expr::new_unary(difference, Operation::Tanh, 1.2, "tanh(x - y)");

        assert_eq!(expr.result, 1.2);
        let difference = expr.operand1.expect("tanh node should have an operand");
        assert_binary(*difference, Operation::Sub, 1.1, 1.0, 2.0, "x - y");
    }

    #[test]
    fn test_tanh() {
        assert_unary(
            Expr::new_leaf(1.0, "x").tanh("tanh(x)"),
            Operation::Tanh,
            0.7615941559557649,
            1.0,
        );

        fn get_tanh(x: f64) -> f64 {
            Expr::new_leaf(x, "x").tanh("tanh(x)").result
        }

        // tanh saturates: |tanh(±10.74) ∓ 1| is about 1e-9, well inside the tolerance.
        assert_float_eq(get_tanh(10.74), 1.0);
        assert_float_eq(get_tanh(-10.74), -1.0);
        assert_float_eq(get_tanh(0.0), 0.0);
    }

    #[test]
    fn test_exp() {
        assert_unary(Expr::new_leaf(1.0, "x").exp("exp(x)"), Operation::Exp, 2.718281828459045, 1.0);
    }

    #[test]
    fn test_relu() {
        assert_unary(Expr::new_leaf(-1.0, "x").relu("relu(x)"), Operation::ReLU, 0.0, -1.0);
        assert_unary(Expr::new_leaf(1.0, "x").relu("relu(x)"), Operation::ReLU, 1.0, 1.0);
    }

    #[test]
    fn test_pow() {
        let expr = Expr::new_leaf(2.0, "x").pow(Expr::new_leaf(3.0, "y"), "x^y");
        assert_binary(expr, Operation::Pow, 8.0, 2.0, 3.0, "x^y");
    }

    #[test]
    fn test_add() {
        let expr = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        assert_binary(expr, Operation::Add, 3.0, 1.0, 2.0, "(x + y)");
    }

    #[test]
    fn test_add_f64() {
        let expr = Expr::new_leaf(1.0, "x") + 2.0;
        assert_binary(expr, Operation::Add, 3.0, 1.0, 2.0, "(x + 2)");
    }

    #[test]
    fn test_add_f64_expr() {
        let expr = 2.0 + Expr::new_leaf(1.0, "x");
        assert_binary(expr, Operation::Add, 3.0, 2.0, 1.0, "(2 + x)");
    }

    #[test]
    fn test_mul() {
        let expr = Expr::new_leaf(2.0, "x") * Expr::new_leaf(3.0, "y");
        assert_binary(expr, Operation::Mul, 6.0, 2.0, 3.0, "(x * y)");
    }

    #[test]
    fn test_mul_f64() {
        let expr = Expr::new_leaf(2.0, "x") * 3.0;
        assert_binary(expr, Operation::Mul, 6.0, 2.0, 3.0, "(x * 3)");
    }

    #[test]
    fn test_mul_f64_expr() {
        let expr = 3.0 * Expr::new_leaf(2.0, "x");
        assert_binary(expr, Operation::Mul, 6.0, 3.0, 2.0, "(3 * x)");
    }

    #[test]
    fn test_sub() {
        let expr = Expr::new_leaf(2.0, "x") - Expr::new_leaf(3.0, "y");
        assert_binary(expr, Operation::Sub, -1.0, 2.0, 3.0, "(x - y)");
    }

    #[test]
    fn test_sub_f64() {
        let expr = Expr::new_leaf(2.0, "x") - 3.0;
        assert_binary(expr, Operation::Sub, -1.0, 2.0, 3.0, "(x - 3)");
    }

    #[test]
    fn test_sub_f64_expr() {
        let expr = 3.0 - Expr::new_leaf(2.0, "x");
        assert_binary(expr, Operation::Sub, 1.0, 3.0, 2.0, "(3 - x)");
    }

    #[test]
    fn test_div() {
        let expr = Expr::new_leaf(6.0, "x") / Expr::new_leaf(3.0, "y");
        assert_binary(expr, Operation::Div, 2.0, 6.0, 3.0, "(x / y)");
    }

    #[test]
    fn test_div_f64() {
        let expr = Expr::new_leaf(6.0, "x") / 3.0;
        assert_binary(expr, Operation::Div, 2.0, 6.0, 3.0, "(x / 3)");
    }

    #[test]
    fn test_log() {
        assert_unary(Expr::new_leaf(2.0, "x").log("log(x)"), Operation::Log, 0.6931471805599453, 2.0);
    }

    #[test]
    fn test_neg() {
        assert_unary(Expr::new_leaf(2.0, "x").neg("neg(x)"), Operation::Neg, -2.0, 2.0);
    }

    #[test]
    fn test_backpropagation_add() {
        let grads = learned_operand_grads(Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y"));
        assert_eq!(grads, (1.0, 1.0));
    }

    #[test]
    fn test_backpropagation_sub() {
        let grads = learned_operand_grads(Expr::new_leaf(1.0, "x") - Expr::new_leaf(2.0, "y"));
        assert_eq!(grads, (1.0, -1.0));
    }

    #[test]
    fn test_backpropagation_mul() {
        let grads = learned_operand_grads(Expr::new_leaf(3.0, "x") * Expr::new_leaf(4.0, "y"));
        assert_eq!(grads, (4.0, 3.0));
    }

    #[test]
    fn test_backpropagation_div() {
        let grads = learned_operand_grads(Expr::new_leaf(3.0, "x") / Expr::new_leaf(4.0, "y"));
        assert_eq!(grads, (0.25, -0.1875));
    }

    #[test]
    fn test_backpropagation_tanh() {
        assert_eq!(learned_operand_grad(Expr::new_leaf(0.0, "x").tanh("tanh(x)")), 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        assert_eq!(learned_operand_grad(Expr::new_leaf(-1.0, "x").relu("relu(x)")), 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        assert_eq!(learned_operand_grad(Expr::new_leaf(0.0, "x").exp("exp(x)")), 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let (base_grad, exponent_grad) =
            learned_operand_grads(Expr::new_leaf(2.0, "x").pow(Expr::new_leaf(3.0, "y"), "x^y"));
        assert_eq!(base_grad, 12.0);
        // 2^3 * ln(2)
        assert_float_eq(exponent_grad, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let sum = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");
        let mut expr = sum.tanh("tanh(x + y)");

        expr.learn(1e-09);

        // 1 - tanh(3)^2
        let expected = 0.009866037165440211;
        let sum = expr.operand1.expect("tanh node should have an operand");
        assert_float_eq(sum.grad, expected);
        let operand1 = sum.operand1.expect("sum should have a first operand");
        let operand2 = sum.operand2.expect("sum should have a second operand");
        assert_float_eq(operand1.grad, expected);
        assert_float_eq(operand2.grad, expected);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let x1 = Expr::new_leaf(2.0, "x1");
        let x2 = Expr::new_leaf(0.0, "x2");
        let w1 = Expr::new_leaf(-3.0, "w1");
        let w2 = Expr::new_leaf(1.0, "w2");
        let b = Expr::new_leaf(6.8813735870195432, "b");

        let x1w1 = x1 * w1;
        let x2w2 = x2 * w2;
        let x1w1_x2w2 = x1w1 + x2w2;
        let n = x1w1_x2w2 + b;
        let mut o = n.tanh("tanh(n)");

        o.learn(1e-09);

        assert_eq!(o.operation, Operation::Tanh);
        assert_eq!(o.grad, 1.0);

        let n = o.operand1.unwrap();
        assert_eq!(n.operation, Operation::Add);
        assert_float_eq(n.grad, 0.5);

        let x1w1_x2w2 = n.operand1.unwrap();
        assert_eq!(x1w1_x2w2.operation, Operation::Add);
        assert_float_eq(x1w1_x2w2.grad, 0.5);

        let b = n.operand2.unwrap();
        assert_eq!(b.operation, Operation::None);
        assert_float_eq(b.grad, 0.5);

        let x1w1 = x1w1_x2w2.operand1.unwrap();
        assert_eq!(x1w1.operation, Operation::Mul);
        assert_float_eq(x1w1.grad, 0.5);

        let x2w2 = x1w1_x2w2.operand2.unwrap();
        assert_eq!(x2w2.operation, Operation::Mul);
        assert_float_eq(x2w2.grad, 0.5);

        let x1 = x1w1.operand1.unwrap();
        assert_eq!(x1.operation, Operation::None);
        assert_float_eq(x1.grad, -1.5);

        let w1 = x1w1.operand2.unwrap();
        assert_eq!(w1.operation, Operation::None);
        assert_float_eq(w1.grad, 1.0);

        let x2 = x2w2.operand1.unwrap();
        assert_eq!(x2.operation, Operation::None);
        assert_float_eq(x2.grad, 0.5);

        let w2 = x2w2.operand2.unwrap();
        assert_eq!(w2.operation, Operation::None);
        assert_float_eq(w2.grad, 0.0);
    }

    #[test]
    fn test_learn_simple() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.learn(1e-01);

        assert_float_eq(expr.result, 0.9);
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0, "x");
        expr.is_learnable = false;
        expr.learn(1e-01);

        assert_float_eq(expr.result, 1.0);
    }

    #[test]
    fn test_find_simple() {
        let expr = Expr::new_leaf(1.0, "x").tanh("tanh(x)");

        let found = expr.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }

    #[test]
    fn test_find_not_found() {
        let expr = Expr::new_leaf(1.0, "x").tanh("tanh(x)");

        assert!(expr.find("y").is_none());
    }

    #[test]
    fn test_find_in_second_operand() {
        let expr = Expr::new_leaf(1.0, "x") + Expr::new_leaf(2.0, "y");

        let found = expr.find("y").expect("y should be found in the second operand");
        assert_eq!(found.result, 2.0);
    }

    #[test]
    fn test_sum_iterator() {
        let leaves = vec![
            Expr::new_leaf(1.0, "x"),
            Expr::new_leaf(2.0, "y"),
            Expr::new_leaf(3.0, "z"),
        ];

        let sum: Expr = leaves.into_iter().sum::<Expr>();
        assert_eq!(sum.result, 6.0);
    }

    #[test]
    fn test_sum_empty_iterator() {
        let sum: Expr = Vec::<Expr>::new().into_iter().sum();
        assert_eq!(sum.result, 0.0);
        assert_eq!(sum.name, "0.0");
    }

    #[test]
    fn test_find_after_clone() {
        let expr = Expr::new_leaf(1.0, "x").tanh("tanh(x)");
        let expr_clone = expr.clone();

        let found = expr_clone.find("x");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "x");
    }
}