alpha_micrograd_rust/
value.rs

1//! A simple library for creating and backpropagating through expression trees.
2//! 
3//! This package includes the following elements to construct expression trees:
4//! - [`Expr`]: a node in the expression tree
5#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// Kind of computation that produced an expression node.
///
/// `None` marks a leaf (no operands), `Tanh`, `Exp` and `ReLU` take a single
/// operand, and `Add`, `Sub`, `Mul`, `Div` and `Pow` take two operands.
///
/// The enum carries no data, so it is `Copy` and can be passed and compared
/// freely without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// Leaf node: the value is stored directly, nothing is computed.
    None,
    /// `operand1 + operand2`
    Add,
    /// `operand1 - operand2`
    Sub,
    /// `operand1 * operand2`
    Mul,
    /// `operand1 / operand2`
    Div,
    /// Hyperbolic tangent of the single operand.
    Tanh,
    /// `e` raised to the single operand.
    Exp,
    /// `operand1` raised to the power `operand2`.
    Pow,
    /// Rectified linear unit: the single operand clamped below at `0.0`.
    ReLU,
}
21
22impl Operation {
23    fn assert_is_type(&self, expr_type: ExprType) {
24        match self {
25            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
26            Operation::Tanh | Operation::Exp | Operation::ReLU => assert_eq!(expr_type, ExprType::Unary),
27            _ => assert_eq!(expr_type, ExprType::Binary),
28        }
29    }
30}
31
/// Shape of an expression node, i.e. how many operands it owns.
///
/// The enum carries no data, so it is `Copy` and cheap to pass by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// No operands: the node stores a plain value.
    Leaf,
    /// Exactly one operand.
    Unary,
    /// Exactly two operands.
    Binary,
}
38
/// Expression representing a node in a calculation graph.
///
/// This struct represents a node in a calculation graph. It can be a leaf node, a unary operation or a binary operation.
///
/// A leaf node holds a value, which is the one that is used in the calculation.
///
/// A unary expression is the result of applying a unary operation to another expression. For example, the result of applying the `tanh` operation to a leaf node.
///
/// A binary expression is the result of applying a binary operation to two other expressions. For example, the result of adding two leaf nodes.
///
/// Every node owns its operands (they are boxed inside it), so the graph is a tree.
#[derive(Debug, Clone)]
pub struct Expr {
    // First (or only) operand; `None` for leaf nodes.
    operand1: Option<Box<Expr>>,
    // Second operand; `Some` only for binary operations.
    operand2: Option<Box<Expr>>,
    // Operation that produced `result`; `Operation::None` for leaf nodes.
    operation: Operation,
    /// The numeric result of the expression, as result of applying the operation to the operands.
    pub result: f64,
    /// Whether the expression is learnable or not. The flag is consulted on leaf nodes during
    /// [`Expr::learn`]: a leaf whose flag is `false` keeps its value unchanged.
    pub is_learnable: bool,
    // Gradient assigned during `learn`: 1.0 on the node `learn` is called on,
    // then derived for each operand from its parent's gradient.
    grad: f64,
    /// The name of the expression, used to identify it in the calculation graph.
    pub name: String,
}
61
62impl Expr {
63    /// Creates a new leaf expression with the given value.
64    /// 
65    /// Example:
66    /// ```rust
67    /// use alpha_micrograd_rust::value::Expr;
68    /// 
69    /// let expr = Expr::new_leaf(1.0, "x");
70    /// ```
71    pub fn new_leaf(value: f64, name: &str) -> Expr {
72        Expr {
73            operand1: None,
74            operand2: None,
75            operation: Operation::None,
76            result: value,
77            is_learnable: true,
78            grad: 0.0,
79            name: name.to_string(),
80        }
81    }
82
83    fn expr_type(&self) -> ExprType {
84        match self.operation {
85            Operation::None => ExprType::Leaf,
86            Operation::Tanh | Operation::Exp | Operation::ReLU => ExprType::Unary,
87            _ => ExprType::Binary,
88        }
89    }
90
91    fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
92        operation.assert_is_type(ExprType::Unary);
93        Expr {
94            operand1: Some(Box::new(operand)),
95            operand2: None,
96            operation,
97            result,
98            is_learnable: true,
99            grad: 0.0,
100            name: name.to_string(),
101        }
102    }
103
104    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
105        operation.assert_is_type(ExprType::Binary);
106        Expr {
107            operand1: Some(Box::new(operand1)),
108            operand2: Some(Box::new(operand2)),
109            operation,
110            result,
111            is_learnable: true,
112            grad: 0.0,
113            name: name.to_string(),
114        }
115    }
116
117    /// Applies the hyperbolic tangent function to the expression and returns it as a new expression.
118    /// 
119    /// Example:
120    /// ```rust
121    /// use alpha_micrograd_rust::value::Expr;
122    /// 
123    /// let expr = Expr::new_leaf(1.0, "x");
124    /// let expr2 = expr.tanh("tanh");
125    /// 
126    /// println!("Result: {}", expr2.result); // 0.7615941559557649
127    /// ```
128    pub fn tanh(self, name: &str) -> Expr {
129        let result = self.result.tanh();
130        Expr::new_unary(self, Operation::Tanh, result, name)
131    }
132
133    /// Applies the rectified linear unit function to the expression and returns it as a new expression.
134    /// 
135    /// Example:
136    /// ```rust
137    /// use alpha_micrograd_rust::value::Expr;
138    /// 
139    /// let expr = Expr::new_leaf(-1.0, "x");
140    /// let expr2 = expr.relu("relu");
141    /// 
142    /// println!("Result: {}", expr2.result); // 0.0
143    /// ```
144    pub fn relu(self, name: &str) -> Expr {
145        let result = self.result.max(0.0);
146        Expr::new_unary(self, Operation::ReLU, result, name)
147    }
148
149    /// Applies the exponential function (e^x) to the expression and returns it as a new expression.
150    /// 
151    /// Example:
152    /// ```rust
153    /// use alpha_micrograd_rust::value::Expr;
154    /// 
155    /// let expr = Expr::new_leaf(1.0, "x");
156    /// let expr2 = expr.exp("exp");
157    /// 
158    /// println!("Result: {}", expr2.result); // 2.718281828459045
159    /// ```
160    pub fn exp(self, name: &str) -> Expr {
161        let result = self.result.exp();
162        Expr::new_unary(self, Operation::Exp, result, name)
163    }
164
165    /// Raises the expression to the power of the given exponent (expression) and returns it as a new expression.
166    /// 
167    /// Example:
168    /// ```rust
169    /// use alpha_micrograd_rust::value::Expr;
170    /// 
171    /// let expr = Expr::new_leaf(2.0, "x");
172    /// let exponent = Expr::new_leaf(3.0, "y");
173    /// let result = expr.pow(exponent, "x^y");
174    /// 
175    /// println!("Result: {}", result.result); // 8.0
176    /// ```
177    pub fn pow(self, exponent: Expr, name: &str) -> Expr {
178        let result = self.result.powf(exponent.result);
179        Expr::new_binary(self, exponent, Operation::Pow, result, name)
180    }
181
182    /// Recalculates the value of the expression recursively, from new values of the operands.
183    /// 
184    /// Usually will be used after a call to [`Expr::learn`], where the gradients have been calculated and
185    /// the internal values of the expression tree have been updated.
186    /// 
187    /// Example:
188    /// ```rust
189    /// use alpha_micrograd_rust::value::Expr;
190    /// 
191    /// let expr = Expr::new_leaf(1.0, "x");
192    /// let mut expr2 = expr.tanh("tanh(x)");
193    /// expr2.learn(1e-09);
194    /// expr2.recalculate();
195    /// ```
196    pub fn recalculate(&mut self) {
197        match self.expr_type() {
198            ExprType::Leaf => {}
199            ExprType::Unary => {
200                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
201                operand1.recalculate();
202
203                self.result = match self.operation {
204                    Operation::Tanh => operand1.result.tanh(),
205                    Operation::Exp => operand1.result.exp(),
206                    Operation::ReLU => operand1.result.max(0.0),
207                    _ => panic!("Invalid unary operation {:?}", self.operation),
208                };
209            }
210            ExprType::Binary => {
211                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
212                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
213
214                operand1.recalculate();
215                operand2.recalculate();
216
217                self.result = match self.operation {
218                    Operation::Add => operand1.result + operand2.result,
219                    Operation::Sub => operand1.result - operand2.result,
220                    Operation::Mul => operand1.result * operand2.result,
221                    Operation::Div => operand1.result / operand2.result,
222                    Operation::Pow => operand1.result.powf(operand2.result),
223                    _ => panic!("Invalid binary operation: {:?}", self.operation),
224                };
225            }
226        }
227    }
228
229    /// Applies backpropagation to the expression, updating the values of the
230    /// gradients and the expression itself.
231    /// 
232    /// This method will change the gradients based on the gradient of the last
233    /// expression in the calculation graph.
234    /// 
235    /// Example:
236    /// 
237    /// ```rust
238    /// use alpha_micrograd_rust::value::Expr;
239    /// 
240    /// let expr = Expr::new_leaf(1.0, "x");
241    /// let mut expr2 = expr.tanh("tanh(x)");
242    /// expr2.learn(1e-09);
243    /// ```
244    /// 
245    /// After adjusting the gradients, the method will update the values of the
246    /// individual expression tree nodes to minimize the loss function.
247    /// 
248    /// In order to get a new calculation of the expression tree, you'll need to call
249    /// [`Expr::recalculate`] after calling [`Expr::learn`].
250    pub fn learn(&mut self, learning_rate: f64) {
251        self.grad = 1.0;
252        self.learn_internal(learning_rate);
253    }
254
255    fn learn_internal(&mut self, learning_rate: f64) {
256        match self.expr_type() {
257            ExprType::Leaf => {
258                // leaves have their gradient set externally by other nodes in the tree
259                // leaves can be learnable, in which case we update the value
260                if self.is_learnable {
261                    self.result -= learning_rate * self.grad;
262                 }
263            }
264            ExprType::Unary => {
265                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
266
267                match self.operation {
268                    Operation::Tanh => {
269                        let tanh_grad = 1.0 - (self.result * self.result);
270                        operand1.grad = self.grad * tanh_grad;
271                    }
272                    Operation::Exp => {
273                        operand1.grad = self.grad * self.result;
274                    }
275                    Operation::ReLU => {
276                        operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
277                    }
278                    _ => panic!("Invalid unary operation {:?}", self.operation),
279                }
280
281                operand1.learn_internal(learning_rate);
282            }
283            ExprType::Binary => {
284                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
285                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
286
287                match self.operation {
288                    Operation::Add => {
289                        operand1.grad = self.grad;
290                        operand2.grad = self.grad;
291                    }
292                    Operation::Sub => {
293                        operand1.grad = self.grad;
294                        operand2.grad = -self.grad;
295                    }
296                    Operation::Mul => {
297                        let operand2_result = operand2.result;
298                        let operand1_result = operand1.result;
299
300                        operand1.grad = self.grad * operand2_result;
301                        operand2.grad = self.grad * operand1_result;
302                    }
303                    Operation::Div => {
304                        let operand2_result = operand2.result;
305                        let operand1_result = operand1.result;
306
307                        operand1.grad = self.grad / operand2_result;
308                        operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
309                    }
310                    Operation::Pow => {
311                        let exponent = operand2.result;
312                        let base = operand1.result;
313
314                        operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
315                        operand2.grad = self.grad * base.powf(exponent) * base.ln();
316                    }
317                    _ => panic!("Invalid binary operation: {:?}", self.operation),
318                }
319
320                operand1.learn_internal(learning_rate);
321                operand2.learn_internal(learning_rate);
322            }
323        }
324    }
325
326    /// Finds a node in the expression tree by its name.
327    /// 
328    /// This method will search the expression tree for a node with the given name.
329    /// If the node is not found, it will return [None].
330    /// 
331    /// Example:
332    /// ```rust
333    /// use alpha_micrograd_rust::value::Expr;
334    /// 
335    /// let expr = Expr::new_leaf(1.0, "x");
336    /// let expr2 = expr.tanh("tanh(x)");
337    /// let original = expr2.find("x");
338    /// 
339    /// assert_eq!(original.expect("Could not find x").result, 1.0);
340    /// ```
341    pub fn find(&self, name: &str) -> Option<&Expr> {
342        if self.name == name {
343            return Some(self);
344        }
345
346        match self.expr_type() {
347            ExprType::Leaf => None,
348            ExprType::Unary => {
349                let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
350                operand1.find(name)
351            }
352            ExprType::Binary => {
353                let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
354                let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
355
356                let result = operand1.find(name);
357                if result.is_some() {
358                    return result;
359                }
360
361                operand2.find(name)
362            }
363        }
364    }
365}
366
367/// Implements the [`Add`] trait for the [`Expr`] struct.
368/// 
369/// This implementation allows the addition of two [`Expr`] objects.
370/// 
371/// Example:
372/// ```rust
373/// use alpha_micrograd_rust::value::Expr;
374/// 
375/// let expr = Expr::new_leaf(1.0, "x");
376/// let expr2 = Expr::new_leaf(2.0, "y");
377/// 
378/// let result = expr + expr2;
379/// 
380/// println!("Result: {}", result.result); // 3.0
381/// ```
382impl Add for Expr {
383    type Output = Expr;
384
385    fn add(self, other: Expr) -> Expr {
386        let result = self.result + other.result;
387        let name = &format!("({} + {})", self.name, other.name);
388        Expr::new_binary(self, other, Operation::Add, result, name)
389    }
390}
391
392/// Implements the [`Add`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
393/// 
394/// This implementation allows the addition of an [`Expr`] object and a [`f64`] value.
395/// 
396/// Example:
397/// ```rust
398/// use alpha_micrograd_rust::value::Expr;
399/// 
400/// let expr = Expr::new_leaf(1.0, "x");
401/// let result = expr + 2.0;
402/// 
403/// println!("Result: {}", result.result); // 3.0
404/// ```
405impl Add<f64> for Expr {
406    type Output = Expr;
407
408    fn add(self, other: f64) -> Expr {
409        let operand2 = Expr::new_leaf(other, &other.to_string());
410        self + operand2
411    }
412}
413
414/// Implements the [`Add`] trait for the [`f64`] type, when the right operand is an [`Expr`].
415/// 
416/// This implementation allows the addition of a [`f64`] value and an [`Expr`] object.
417/// 
418/// Example:
419/// ```rust
420/// use alpha_micrograd_rust::value::Expr;
421/// 
422/// let expr = Expr::new_leaf(1.0, "x");
423/// let result = 2.0 + expr;
424/// 
425/// println!("Result: {}", result.result); // 3.0
426/// ```
427impl Add<Expr> for f64 {
428    type Output = Expr;
429
430    fn add(self, other: Expr) -> Expr {
431        let operand1 = Expr::new_leaf(self, &self.to_string());
432        operand1 + other
433    }
434}
435
436/// Implements the [`Mul`] trait for the [`Expr`] struct.
437/// 
438/// This implementation allows the multiplication of two [`Expr`] objects.
439/// 
440/// Example:
441/// 
442/// ```rust
443/// use alpha_micrograd_rust::value::Expr;
444/// 
445/// let expr = Expr::new_leaf(1.0, "x");
446/// let expr2 = Expr::new_leaf(2.0, "y");
447/// 
448/// let result = expr * expr2;
449/// 
450/// println!("Result: {}", result.result); // 2.0
451/// ```
452impl Mul for Expr {
453    type Output = Expr;
454
455    fn mul(self, other: Expr) -> Expr {
456        let result = self.result * other.result;
457        let name = &format!("({} * {})", self.name, other.name);
458        Expr::new_binary(self, other, Operation::Mul, result, name)
459    }
460}
461
462/// Implements the [`Mul`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
463/// 
464/// This implementation allows the multiplication of an [`Expr`] object and a [`f64`] value.
465/// 
466/// Example:
467/// 
468/// ```rust
469/// use alpha_micrograd_rust::value::Expr;
470/// 
471/// let expr = Expr::new_leaf(1.0, "x");
472/// let result = expr * 2.0;
473/// 
474/// println!("Result: {}", result.result); // 2.0
475/// ```
476impl Mul<f64> for Expr {
477    type Output = Expr;
478
479    fn mul(self, other: f64) -> Expr {
480        let operand2 = Expr::new_leaf(other, &other.to_string());
481        self * operand2
482    }
483}
484
485/// Implements the [`Mul`] trait for the [`f64`] type, when the right operand is an [`Expr`].
486/// 
487/// This implementation allows the multiplication of a [`f64`] value and an [`Expr`] object.
488/// 
489/// Example:
490/// 
491/// ```rust
492/// use alpha_micrograd_rust::value::Expr;
493/// 
494/// let expr = Expr::new_leaf(1.0, "x");
495/// let result = 2.0 * expr;
496/// 
497/// println!("Result: {}", result.result); // 2.0
498/// ```
499impl Mul<Expr> for f64 {
500    type Output = Expr;
501
502    fn mul(self, other: Expr) -> Expr {
503        let operand1 = Expr::new_leaf(self, &self.to_string());
504        operand1 * other
505    }
506}
507
508/// Implements the [`Sub`] trait for the [`Expr`] struct.
509/// 
510/// This implementation allows the subtraction of two [`Expr`] objects.
511/// 
512/// Example:
513/// 
514/// ```rust
515/// use alpha_micrograd_rust::value::Expr;
516/// 
517/// let expr = Expr::new_leaf(1.0, "x");
518/// let expr2 = Expr::new_leaf(2.0, "y");
519/// 
520/// let result = expr - expr2;
521/// 
522/// println!("Result: {}", result.result); // -1.0
523/// ```
524impl Sub for Expr {
525    type Output = Expr;
526
527    fn sub(self, other: Expr) -> Expr {
528        let result = self.result - other.result;
529        let name = &format!("({} - {})", self.name, other.name);
530        Expr::new_binary(self, other, Operation::Sub, result, name)
531    }
532}
533
534/// Implements the [`Sub`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
535/// 
536/// This implementation allows the subtraction of an [`Expr`] object and a [`f64`] value.
537/// 
538/// Example:
539/// 
540/// ```rust
541/// use alpha_micrograd_rust::value::Expr;
542/// 
543/// let expr = Expr::new_leaf(1.0, "x");
544/// let result = expr - 2.0;
545/// 
546/// println!("Result: {}", result.result); // -1.0
547/// ```
548impl Sub<f64> for Expr {
549    type Output = Expr;
550
551    fn sub(self, other: f64) -> Expr {
552        let operand2 = Expr::new_leaf(other, &other.to_string());
553        self - operand2
554    }
555}
556
557/// Implements the [`Sub`] trait for the [`f64`] type, when the right operand is an [`Expr`].
558/// 
559/// This implementation allows the subtraction of a [`f64`] value and an [`Expr`] object.
560/// 
561/// Example:
562/// 
563/// ```rust
564/// use alpha_micrograd_rust::value::Expr;
565/// 
566/// let expr = Expr::new_leaf(1.0, "x");
567/// let result = 2.0 - expr;
568/// 
569/// println!("Result: {}", result.result); // 1.0
570/// ```
571impl Sub<Expr> for f64 {
572    type Output = Expr;
573
574    fn sub(self, other: Expr) -> Expr {
575        let operand1 = Expr::new_leaf(self, &self.to_string());
576        operand1 - other
577    }
578}
579
580/// Implements the [`Div`] trait for the [`Expr`] struct.
581/// 
582/// This implementation allows the division of two [`Expr`] objects.
583/// 
584/// Example:
585/// 
586/// ```rust
587/// use alpha_micrograd_rust::value::Expr;
588/// 
589/// let expr = Expr::new_leaf(1.0, "x");
590/// let expr2 = Expr::new_leaf(2.0, "y");
591/// 
592/// let result = expr / expr2;
593/// 
594/// println!("Result: {}", result.result); // 0.5
595/// ```
596impl Div for Expr {
597    type Output = Expr;
598
599    fn div(self, other: Expr) -> Expr {
600        let result = self.result / other.result;
601        let name = &format!("({} / {})", self.name, other.name);
602        Expr::new_binary(self, other, Operation::Div, result, name)
603    }
604}
605
606/// Implements the [`Div`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
607/// 
608/// This implementation allows the division of an [`Expr`] object and a [`f64`] value.
609/// 
610/// Example:
611/// 
612/// ```rust
613/// use alpha_micrograd_rust::value::Expr;
614/// 
615/// let expr = Expr::new_leaf(1.0, "x");
616/// let result = expr / 2.0;
617/// 
618/// println!("Result: {}", result.result); // 0.5
619/// ```
620impl Div<f64> for Expr {
621    type Output = Expr;
622
623    fn div(self, other: f64) -> Expr {
624        let operand2 = Expr::new_leaf(other, &other.to_string());
625        self / operand2
626    }
627}
628
629/// Implements the [`Sum`] trait for the [`Expr`] struct.
630/// 
631/// Note that this implementation will generate temporary [`Expr`] objects,
632/// which may not be the most efficient way to sum a collection of [`Expr`] objects.
633/// However, it is provided as a convenience method for users that want to use sum
634/// over an [`Iterator<Expr>`].
635/// 
636/// Example:
637/// 
638/// ```rust
639/// use alpha_micrograd_rust::value::Expr;
640/// 
641/// let expr = Expr::new_leaf(1.0, "x");
642/// let expr2 = Expr::new_leaf(2.0, "y");
643/// let expr3 = Expr::new_leaf(3.0, "z");
644/// 
645/// let sum = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
646/// 
647/// println!("Result: {}", sum.result); // 6.0
648/// ```
649impl Sum for Expr {
650    fn sum<I>(iter: I) -> Self
651    where
652        I: Iterator<Item = Self>,
653    {
654        iter.reduce(|acc, x| acc + x)
655            .unwrap_or(Expr::new_leaf(0.0, "0.0"))
656    }
657}
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662
663    fn assert_float_eq(f1: f64, f2: f64) {
664        let places = 7;
665        let tolerance = 10.0_f64.powi(-places);
666        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
667    }
668
669    #[test]
670    fn test() {
671        let expr = Expr::new_leaf(1.0, "x");
672        assert_eq!(expr.result, 1.0);
673    }
674
675    #[test]
676    fn test_unary() {
677        let expr = Expr::new_leaf(1.0, "x");
678        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");
679
680        assert_eq!(expr2.result, 1.1);
681        assert_eq!(expr2.operand1.unwrap().result, 1.0);
682    }
683
684    #[test]
685    #[should_panic]
686    fn test_unary_expression_type_check() {
687        let expr = Expr::new_leaf(1.0, "x");
688        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
689    }
690
691    #[test]
692    fn test_binary() {
693        let expr = Expr::new_leaf(1.0, "x");
694        let expr2 = Expr::new_leaf(2.0, "y");
695        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");
696
697        assert_eq!(expr3.result, 1.1);
698        assert_eq!(expr3.operand1.unwrap().result, 1.0);
699        assert_eq!(expr3.operand2.unwrap().result, 2.0);
700    }
701
702    #[test]
703    #[should_panic]
704    fn test_binary_expression_type_check() {
705        let expr = Expr::new_leaf(1.0, "x");
706        let expr2 = Expr::new_leaf(2.0, "y");
707        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
708    }
709
710    #[test]
711    fn test_mixed_tree() {
712        let expr = Expr::new_leaf(1.0, "x");
713        let expr2 = Expr::new_leaf(2.0, "y");
714        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
715        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");
716
717        assert_eq!(expr4.result, 1.2);
718        let expr3 = expr4.operand1.unwrap();
719        assert_eq!(expr3.result, 1.1);
720        assert_eq!(expr3.operand1.unwrap().result, 1.0);
721        assert_eq!(expr3.operand2.unwrap().result, 2.0);
722    }
723
724    #[test]
725    fn test_tanh() {
726        let expr = Expr::new_leaf(1.0, "x");
727        let expr2 = expr.tanh("tanh(x)");
728
729        assert_eq!(expr2.result, 0.7615941559557649);
730        assert!(expr2.operand1.is_some());
731        assert_eq!(expr2.operand1.unwrap().result, 1.0);
732        assert_eq!(expr2.operation, Operation::Tanh);
733        assert!(expr2.operand2.is_none());
734
735        // Some other known values
736        fn get_tanh(x: f64) -> f64 {
737            Expr::new_leaf(x, "x").tanh("tanh(x)").result
738        }
739
740        assert_float_eq(get_tanh(10.74), 0.9999999);
741        assert_float_eq(get_tanh(-10.74), -0.9999999);
742        assert_float_eq(get_tanh(0.0), 0.0);
743    }
744
745    #[test]
746    fn test_exp() {
747        let expr = Expr::new_leaf(1.0, "x");
748        let expr2 = expr.exp("exp(x)");
749
750        assert_eq!(expr2.result, 2.718281828459045);
751        assert!(expr2.operand1.is_some());
752        assert_eq!(expr2.operand1.unwrap().result, 1.0);
753        assert_eq!(expr2.operation, Operation::Exp);
754        assert!(expr2.operand2.is_none());
755    }
756
757    #[test]
758    fn test_relu() {
759        // negative case
760        let expr = Expr::new_leaf(-1.0, "x");
761        let expr2 = expr.relu("relu(x)");
762
763        assert_eq!(expr2.result, 0.0);
764        assert!(expr2.operand1.is_some());
765        assert_eq!(expr2.operand1.unwrap().result, -1.0);
766        assert_eq!(expr2.operation, Operation::ReLU);
767        assert!(expr2.operand2.is_none());
768
769        // positive case
770        let expr = Expr::new_leaf(1.0, "x");
771        let expr2 = expr.relu("relu(x)");
772
773        assert_eq!(expr2.result, 1.0);
774        assert!(expr2.operand1.is_some());
775        assert_eq!(expr2.operand1.unwrap().result, 1.0);
776        assert_eq!(expr2.operation, Operation::ReLU);
777        assert!(expr2.operand2.is_none());
778    }
779
780    #[test]
781    fn test_pow() {
782        let expr = Expr::new_leaf(2.0, "x");
783        let expr2 = Expr::new_leaf(3.0, "y");
784        let result = expr.pow(expr2, "x^y");
785
786        assert_eq!(result.result, 8.0);
787        assert!(result.operand1.is_some());
788        assert_eq!(result.operand1.unwrap().result, 2.0);
789        assert_eq!(result.operation, Operation::Pow);
790        
791        assert!(result.operand2.is_some());
792        assert_eq!(result.operand2.unwrap().result, 3.0);
793    }
794
795    #[test]
796    fn test_add() {
797        let expr = Expr::new_leaf(1.0, "x");
798        let expr2 = Expr::new_leaf(2.0, "y");
799        let expr3 = expr + expr2;
800
801        assert_eq!(expr3.result, 3.0);
802        assert!(expr3.operand1.is_some());
803        assert_eq!(expr3.operand1.unwrap().result, 1.0);
804        assert!(expr3.operand2.is_some());
805        assert_eq!(expr3.operand2.unwrap().result, 2.0);
806        assert_eq!(expr3.operation, Operation::Add);
807        assert_eq!(expr3.name, "(x + y)");
808    }
809
810    #[test]
811    fn test_add_f64() {
812        let expr = Expr::new_leaf(1.0, "x");
813        let expr2 = expr + 2.0;
814
815        assert_eq!(expr2.result, 3.0);
816        assert!(expr2.operand1.is_some());
817        assert_eq!(expr2.operand1.unwrap().result, 1.0);
818        assert!(expr2.operand2.is_some());
819        assert_eq!(expr2.operand2.unwrap().result, 2.0);
820        assert_eq!(expr2.operation, Operation::Add);
821        assert_eq!(expr2.name, "(x + 2)");
822    }
823
824    #[test]
825    fn test_add_f64_expr() {
826        let expr = Expr::new_leaf(1.0, "x");
827        let expr2 = 2.0 + expr;
828
829        assert_eq!(expr2.result, 3.0);
830        assert!(expr2.operand1.is_some());
831        assert_eq!(expr2.operand1.unwrap().result, 2.0);
832        assert!(expr2.operand2.is_some());
833        assert_eq!(expr2.operand2.unwrap().result, 1.0);
834        assert_eq!(expr2.operation, Operation::Add);
835        assert_eq!(expr2.name, "(2 + x)");
836    }
837
838    #[test]
839    fn test_mul() {
840        let expr = Expr::new_leaf(2.0, "x");
841        let expr2 = Expr::new_leaf(3.0, "y");
842        let expr3 = expr * expr2;
843
844        assert_eq!(expr3.result, 6.0);
845        assert!(expr3.operand1.is_some());
846        assert_eq!(expr3.operand1.unwrap().result, 2.0);
847        assert!(expr3.operand2.is_some());
848        assert_eq!(expr3.operand2.unwrap().result, 3.0);
849        assert_eq!(expr3.operation, Operation::Mul);
850        assert_eq!(expr3.name, "(x * y)");
851    }
852
853    #[test]
854    fn test_mul_f64() {
855        let expr = Expr::new_leaf(2.0, "x");
856        let expr2 = expr * 3.0;
857
858        assert_eq!(expr2.result, 6.0);
859        assert!(expr2.operand1.is_some());
860        assert_eq!(expr2.operand1.unwrap().result, 2.0);
861        assert!(expr2.operand2.is_some());
862        assert_eq!(expr2.operand2.unwrap().result, 3.0);
863        assert_eq!(expr2.operation, Operation::Mul);
864        assert_eq!(expr2.name, "(x * 3)");
865    }
866
867    #[test]
868    fn test_mul_f64_expr() {
869        let expr = Expr::new_leaf(2.0, "x");
870        let expr2 = 3.0 * expr;
871
872        assert_eq!(expr2.result, 6.0);
873        assert!(expr2.operand1.is_some());
874        assert_eq!(expr2.operand1.unwrap().result, 3.0);
875        assert!(expr2.operand2.is_some());
876        assert_eq!(expr2.operand2.unwrap().result, 2.0);
877        assert_eq!(expr2.operation, Operation::Mul);
878        assert_eq!(expr2.name, "(3 * x)");
879    }
880
881    #[test]
882    fn test_sub() {
883        let expr = Expr::new_leaf(2.0, "x");
884        let expr2 = Expr::new_leaf(3.0, "y");
885        let expr3 = expr - expr2;
886
887        assert_eq!(expr3.result, -1.0);
888        assert!(expr3.operand1.is_some());
889        assert_eq!(expr3.operand1.unwrap().result, 2.0);
890        assert!(expr3.operand2.is_some());
891        assert_eq!(expr3.operand2.unwrap().result, 3.0);
892        assert_eq!(expr3.operation, Operation::Sub);
893        assert_eq!(expr3.name, "(x - y)");
894    }
895
896    #[test]
897    fn test_sub_f64() {
898        let expr = Expr::new_leaf(2.0, "x");
899        let expr2 = expr - 3.0;
900
901        assert_eq!(expr2.result, -1.0);
902        assert!(expr2.operand1.is_some());
903        assert_eq!(expr2.operand1.unwrap().result, 2.0);
904        assert!(expr2.operand2.is_some());
905        assert_eq!(expr2.operand2.unwrap().result, 3.0);
906        assert_eq!(expr2.operation, Operation::Sub);
907        assert_eq!(expr2.name, "(x - 3)");
908    }
909
910    #[test]
911    fn test_sub_f64_expr() {
912        let expr = Expr::new_leaf(2.0, "x");
913        let expr2 = 3.0 - expr;
914
915        assert_eq!(expr2.result, 1.0);
916        assert!(expr2.operand1.is_some());
917        assert_eq!(expr2.operand1.unwrap().result, 3.0);
918        assert!(expr2.operand2.is_some());
919        assert_eq!(expr2.operand2.unwrap().result, 2.0);
920        assert_eq!(expr2.operation, Operation::Sub);
921        assert_eq!(expr2.name, "(3 - x)");
922    }
923
924    #[test]
925    fn test_div() {
926        let expr = Expr::new_leaf(6.0, "x");
927        let expr2 = Expr::new_leaf(3.0, "y");
928        let expr3 = expr / expr2;
929
930        assert_eq!(expr3.result, 2.0);
931        assert!(expr3.operand1.is_some());
932        assert_eq!(expr3.operand1.unwrap().result, 6.0);
933        assert!(expr3.operand2.is_some());
934        assert_eq!(expr3.operand2.unwrap().result, 3.0);
935        assert_eq!(expr3.operation, Operation::Div);
936        assert_eq!(expr3.name, "(x / y)");
937    }
938
939    #[test]
940    fn test_div_f64() {
941        let expr = Expr::new_leaf(6.0, "x");
942        let expr2 = expr / 3.0;
943
944        assert_eq!(expr2.result, 2.0);
945        assert!(expr2.operand1.is_some());
946        assert_eq!(expr2.operand1.unwrap().result, 6.0);
947        assert!(expr2.operand2.is_some());
948        assert_eq!(expr2.operand2.unwrap().result, 3.0);
949        assert_eq!(expr2.operation, Operation::Div);
950        assert_eq!(expr2.name, "(x / 3)");
951    }
952
953    #[test]
954    fn test_backpropagation_add() {
955        let operand1 = Expr::new_leaf(1.0, "x");
956        let operand2 = Expr::new_leaf(2.0, "y");
957        let mut expr3 = operand1 + operand2;
958
959        expr3.learn(1e-09);
960
961        let operand1 = expr3.operand1.unwrap();
962        let operand2 = expr3.operand2.unwrap();
963        assert_eq!(operand1.grad, 1.0);
964        assert_eq!(operand2.grad, 1.0);
965    }
966
967    #[test]
968    fn test_backpropagation_sub() {
969        let operand1 = Expr::new_leaf(1.0, "x");
970        let operand2 = Expr::new_leaf(2.0, "y");
971        let mut expr3 = operand1 - operand2;
972
973        expr3.learn(1e-09);
974
975        let operand1 = expr3.operand1.unwrap();
976        let operand2 = expr3.operand2.unwrap();
977        assert_eq!(operand1.grad, 1.0);
978        assert_eq!(operand2.grad, -1.0);
979    }
980
981    #[test]
982    fn test_backpropagation_mul() {
983        let operand1 = Expr::new_leaf(3.0, "x");
984        let operand2 = Expr::new_leaf(4.0, "y");
985        let mut expr3 = operand1 * operand2;
986
987        expr3.learn(1e-09);
988
989        let operand1 = expr3.operand1.unwrap();
990        let operand2 = expr3.operand2.unwrap();
991        assert_eq!(operand1.grad, 4.0);
992        assert_eq!(operand2.grad, 3.0);
993    }
994
995    #[test]
996    fn test_backpropagation_div() {
997        let operand1 = Expr::new_leaf(3.0, "x");
998        let operand2 = Expr::new_leaf(4.0, "y");
999        let mut expr3 = operand1 / operand2;
1000
1001        expr3.learn(1e-09);
1002
1003        let operand1 = expr3.operand1.unwrap();
1004        let operand2 = expr3.operand2.unwrap();
1005        assert_eq!(operand1.grad, 0.25);
1006        assert_eq!(operand2.grad, -0.1875);
1007    }
1008
1009    #[test]
1010    fn test_backpropagation_tanh() {
1011        let operand1 = Expr::new_leaf(0.0, "x");
1012        let mut expr2 = operand1.tanh("tanh(x)");
1013
1014        expr2.learn(1e-09);
1015
1016        let operand1 = expr2.operand1.unwrap();
1017        assert_eq!(operand1.grad, 1.0);
1018    }
1019
1020    #[test]
1021    fn test_backpropagation_relu() {
1022        let operand1 = Expr::new_leaf(-1.0, "x");
1023        let mut expr2 = operand1.relu("relu(x)");
1024
1025        expr2.learn(1e-09);
1026
1027        let operand1 = expr2.operand1.unwrap();
1028        assert_eq!(operand1.grad, 0.0);
1029    }
1030
1031    #[test]
1032    fn test_backpropagation_exp() {
1033        let operand1 = Expr::new_leaf(0.0, "x");
1034        let mut expr2 = operand1.exp("exp(x)");
1035
1036        expr2.learn(1e-09);
1037
1038        let operand1 = expr2.operand1.unwrap();
1039        assert_eq!(operand1.grad, 1.0);
1040    }
1041
1042    #[test]
1043    fn test_backpropagation_pow() {
1044        let operand1 = Expr::new_leaf(2.0, "x");
1045        let operand2 = Expr::new_leaf(3.0, "y");
1046        let mut expr3 = operand1.pow(operand2, "x^y");
1047
1048        expr3.learn(1e-09);
1049
1050        let operand1 = expr3.operand1.unwrap();
1051        let operand2 = expr3.operand2.unwrap();
1052        assert_eq!(operand1.grad, 12.0);
1053        assert_eq!(operand2.grad, 5.545177444479562);
1054    }
1055
1056    #[test]
1057    fn test_backpropagation_mixed_tree() {
1058        let operand1 = Expr::new_leaf(1.0, "x");
1059        let operand2 = Expr::new_leaf(2.0, "y");
1060        let expr3 = operand1 + operand2;
1061        let mut expr4 = expr3.tanh("tanh(x + y)");
1062
1063        expr4.learn(1e-09);
1064
1065        let expr3 = expr4.operand1.unwrap();
1066        let operand1 = expr3.operand1.unwrap();
1067        let operand2 = expr3.operand2.unwrap();
1068
1069        assert_eq!(expr3.grad, 0.009866037165440211);
1070        assert_eq!(operand1.grad, 0.009866037165440211);
1071        assert_eq!(operand2.grad, 0.009866037165440211);
1072    }
1073
1074    #[test]
1075    fn test_backpropagation_karpathys_example() {
1076        let x1 = Expr::new_leaf(2.0, "x1");
1077        let x2 = Expr::new_leaf(0.0, "x2");
1078        let w1 = Expr::new_leaf(-3.0, "w1");
1079        let w2 = Expr::new_leaf(1.0, "w2");
1080        let b = Expr::new_leaf(6.8813735870195432, "b");
1081
1082        let x1w1 = x1 * w1;
1083        let x2w2 = x2 * w2;
1084        let x1w1_x2w2 = x1w1 + x2w2;
1085        let n = x1w1_x2w2 + b;
1086        let mut o = n.tanh("tanh(n)");
1087
1088        o.learn(1e-09);
1089
1090        assert_eq!(o.operation, Operation::Tanh);
1091        assert_eq!(o.grad, 1.0);
1092
1093        let n = o.operand1.unwrap();
1094        assert_eq!(n.operation, Operation::Add);
1095        assert_float_eq(n.grad, 0.5);
1096
1097        let x1w1_x2w2 = n.operand1.unwrap();
1098        assert_eq!(x1w1_x2w2.operation, Operation::Add);
1099        assert_float_eq(x1w1_x2w2.grad, 0.5);
1100
1101        let b = n.operand2.unwrap();
1102        assert_eq!(b.operation, Operation::None);
1103        assert_float_eq(b.grad, 0.5);
1104
1105        let x1w1 = x1w1_x2w2.operand1.unwrap();
1106        assert_eq!(x1w1.operation, Operation::Mul);
1107        assert_float_eq(x1w1.grad, 0.5);
1108
1109        let x2w2 = x1w1_x2w2.operand2.unwrap();
1110        assert_eq!(x2w2.operation, Operation::Mul);
1111        assert_float_eq(x2w2.grad, 0.5);
1112
1113        let x1 = x1w1.operand1.unwrap();
1114        assert_eq!(x1.operation, Operation::None);
1115        assert_float_eq(x1.grad, -1.5);
1116
1117        let w1 = x1w1.operand2.unwrap();
1118        assert_eq!(w1.operation, Operation::None);
1119        assert_float_eq(w1.grad, 1.0);
1120
1121        let x2 = x2w2.operand1.unwrap();
1122        assert_eq!(x2.operation, Operation::None);
1123        assert_float_eq(x2.grad, 0.5);
1124
1125        let w2 = x2w2.operand2.unwrap();
1126        assert_eq!(w2.operation, Operation::None);
1127        assert_float_eq(w2.grad, 0.0);
1128    }
1129
1130    #[test]
1131    fn test_learn_simple() {
1132        let mut expr = Expr::new_leaf(1.0, "x");
1133        expr.learn(1e-01);
1134
1135        assert_float_eq(expr.result, 0.9);
1136    }
1137
1138    #[test]
1139    fn test_learn_skips_non_learnable() {
1140        let mut expr = Expr::new_leaf(1.0, "x");
1141        expr.is_learnable = false;
1142        expr.learn(1e-01);
1143
1144        assert_float_eq(expr.result, 1.0);
1145    }
1146
1147    #[test]
1148    fn test_find_simple() {
1149        let expr = Expr::new_leaf(1.0, "x");
1150        let expr2 = expr.tanh("tanh(x)");
1151
1152        let found = expr2.find("x");
1153        assert!(found.is_some());
1154        assert_eq!(found.unwrap().name, "x");
1155    }
1156
1157    #[test]
1158    fn test_find_not_found() {
1159        let expr = Expr::new_leaf(1.0, "x");
1160        let expr2 = expr.tanh("tanh(x)");
1161
1162        let found = expr2.find("y");
1163        assert!(found.is_none());
1164    }
1165
1166    #[test]
1167    fn test_sum_iterator() {
1168        let expr = Expr::new_leaf(1.0, "x");
1169        let expr2 = Expr::new_leaf(2.0, "y");
1170        let expr3 = Expr::new_leaf(3.0, "z");
1171
1172        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
1173        assert_eq!(sum.result, 6.0);
1174    }
1175
1176    #[test]
1177    fn test_find_after_clone() {
1178        let expr = Expr::new_leaf(1.0, "x");
1179        let expr2 = expr.tanh("tanh(x)");
1180        let expr2_clone = expr2.clone();
1181
1182        let found = expr2_clone.find("x");
1183        assert!(found.is_some());
1184        assert_eq!(found.unwrap().name, "x");
1185    }
1186}