alpha_micrograd_rust/
value.rs

//! A simple library for building expression trees and backpropagating through them.
//!
//! This module provides the following element for constructing expression trees:
//! - [`Expr`]: a node in the expression tree
5#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced an expression node from its operands.
///
/// Variants are grouped by arity: `None` marks a leaf, the next group takes a
/// single operand and the last group takes two.
#[derive(Clone, Debug, PartialEq)]
enum Operation {
    /// No operation: the node is a leaf holding a plain value.
    None,

    // Unary operations.
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential function (e^x).
    Exp,
    /// Rectified linear unit, max(x, 0).
    ReLU,
    /// Natural logarithm.
    Log,
    /// Arithmetic negation.
    Neg,

    // Binary operations.
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Exponentiation (base^exponent).
    Pow,
}
23
24impl Operation {
25    fn assert_is_type(&self, expr_type: ExprType) {
26        match self {
27            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
28            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
29            _ => assert_eq!(expr_type, ExprType::Binary),
30        }
31    }
32}
33
/// Arity of an expression node, i.e. how many operands it owns.
#[derive(PartialEq, Debug)]
enum ExprType {
    /// No operands: the node holds a plain value.
    Leaf,
    /// Exactly one operand (`operand1`).
    Unary,
    /// Two operands (`operand1` and `operand2`).
    Binary,
}
40
/// Expression representing a node in a calculation graph.
///
/// This struct represents a node in a calculation graph. It can be a leaf node, a unary operation or a binary operation.
///
/// A leaf node holds a value, which is the one that is used in the calculation.
///
/// A unary expression is the result of applying a unary operation to another expression. For example, the result of applying the `tanh` operation to a leaf node.
///
/// A binary expression is the result of applying a binary operation to two other expressions. For example, the result of adding two leaf nodes.
///
/// Every node owns its operands (through `Box`), so the graph is a tree: a
/// sub-expression cannot be shared between two parent nodes.
#[derive(Debug, Clone)]
pub struct Expr {
    // First operand: `None` for leaves, `Some` for unary and binary expressions.
    operand1: Option<Box<Expr>>,
    // Second operand: `Some` only for binary expressions.
    operand2: Option<Box<Expr>>,
    // Operation that combines the operands into `result` (`Operation::None` for leaves).
    operation: Operation,
    /// The numeric result of the expression, as result of applying the operation to the operands.
    pub result: f64,
    /// Whether the expression is learnable or not. Only learnable [`Expr`] will have their values updated during backpropagation (learning).
    ///
    /// This flag is only consulted on leaf nodes; intermediate nodes are never updated directly.
    pub is_learnable: bool,
    // Gradient of the expression `learn` was called on with respect to this node.
    // Starts at 0.0 and is overwritten during backpropagation.
    grad: f64,
    /// The name of the expression, used to identify it in the calculation graph.
    pub name: String,
}
63
64impl Expr {
65    /// Creates a new leaf expression with the given value.
66    /// 
67    /// Example:
68    /// ```rust
69    /// use alpha_micrograd_rust::value::Expr;
70    /// 
71    /// let expr = Expr::new_leaf(1.0, "x");
72    /// ```
73    pub fn new_leaf(value: f64, name: &str) -> Expr {
74        Expr {
75            operand1: None,
76            operand2: None,
77            operation: Operation::None,
78            result: value,
79            is_learnable: true,
80            grad: 0.0,
81            name: name.to_string(),
82        }
83    }
84
85    fn expr_type(&self) -> ExprType {
86        match self.operation {
87            Operation::None => ExprType::Leaf,
88            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
89            _ => ExprType::Binary,
90        }
91    }
92
93    fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
94        operation.assert_is_type(ExprType::Unary);
95        Expr {
96            operand1: Some(Box::new(operand)),
97            operand2: None,
98            operation,
99            result,
100            is_learnable: false,
101            grad: 0.0,
102            name: name.to_string(),
103        }
104    }
105
106    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
107        operation.assert_is_type(ExprType::Binary);
108        Expr {
109            operand1: Some(Box::new(operand1)),
110            operand2: Some(Box::new(operand2)),
111            operation,
112            result,
113            is_learnable: false,
114            grad: 0.0,
115            name: name.to_string(),
116        }
117    }
118
119    /// Applies the hyperbolic tangent function to the expression and returns it as a new expression.
120    /// 
121    /// Example:
122    /// ```rust
123    /// use alpha_micrograd_rust::value::Expr;
124    /// 
125    /// let expr = Expr::new_leaf(1.0, "x");
126    /// let expr2 = expr.tanh("tanh");
127    /// 
128    /// println!("Result: {}", expr2.result); // 0.7615941559557649
129    /// ```
130    pub fn tanh(self, name: &str) -> Expr {
131        let result = self.result.tanh();
132        Expr::new_unary(self, Operation::Tanh, result, name)
133    }
134
135    /// Applies the rectified linear unit function to the expression and returns it as a new expression.
136    /// 
137    /// Example:
138    /// ```rust
139    /// use alpha_micrograd_rust::value::Expr;
140    /// 
141    /// let expr = Expr::new_leaf(-1.0, "x");
142    /// let expr2 = expr.relu("relu");
143    /// 
144    /// println!("Result: {}", expr2.result); // 0.0
145    /// ```
146    pub fn relu(self, name: &str) -> Expr {
147        let result = self.result.max(0.0);
148        Expr::new_unary(self, Operation::ReLU, result, name)
149    }
150
151    /// Applies the exponential function (e^x) to the expression and returns it as a new expression.
152    /// 
153    /// Example:
154    /// ```rust
155    /// use alpha_micrograd_rust::value::Expr;
156    /// 
157    /// let expr = Expr::new_leaf(1.0, "x");
158    /// let expr2 = expr.exp("exp");
159    /// 
160    /// println!("Result: {}", expr2.result); // 2.718281828459045
161    /// ```
162    pub fn exp(self, name: &str) -> Expr {
163        let result = self.result.exp();
164        Expr::new_unary(self, Operation::Exp, result, name)
165    }
166
167    /// Raises the expression to the power of the given exponent (expression) and returns it as a new expression.
168    /// 
169    /// Example:
170    /// ```rust
171    /// use alpha_micrograd_rust::value::Expr;
172    /// 
173    /// let expr = Expr::new_leaf(2.0, "x");
174    /// let exponent = Expr::new_leaf(3.0, "y");
175    /// let result = expr.pow(exponent, "x^y");
176    /// 
177    /// println!("Result: {}", result.result); // 8.0
178    /// ```
179    pub fn pow(self, exponent: Expr, name: &str) -> Expr {
180        let result = self.result.powf(exponent.result);
181        Expr::new_binary(self, exponent, Operation::Pow, result, name)
182    }
183
184    /// Applies the natural logarithm function to the expression and returns it as a new expression.
185    /// 
186    /// Example:
187    /// ```rust
188    /// use alpha_micrograd_rust::value::Expr;
189    /// 
190    /// let expr = Expr::new_leaf(2.0, "x");
191    /// let expr2 = expr.log("log");
192    /// 
193    /// println!("Result: {}", expr2.result); // 0.6931471805599453
194    /// ```
195    pub fn log(self, name: &str) -> Expr {
196        let result = self.result.ln();
197        Expr::new_unary(self, Operation::Log, result, name)
198    }
199
200    /// Negates the expression and returns it as a new expression.
201    /// 
202    /// Example:
203    /// ```rust
204    /// use alpha_micrograd_rust::value::Expr;
205    /// 
206    /// let expr = Expr::new_leaf(1.0, "x");
207    /// let expr2 = expr.neg("neg");
208    /// 
209    /// println!("Result: {}", expr2.result); // -1.0
210    /// ```
211    pub fn neg(self, name: &str) -> Expr {
212        let result = -self.result;
213        Expr::new_unary(self, Operation::Neg, result, name)
214    }
215
216    /// Recalculates the value of the expression recursively, from new values of the operands.
217    /// 
218    /// Usually will be used after a call to [`Expr::learn`], where the gradients have been calculated and
219    /// the internal values of the expression tree have been updated.
220    /// 
221    /// Example:
222    /// ```rust
223    /// use alpha_micrograd_rust::value::Expr;
224    /// 
225    /// let expr = Expr::new_leaf(1.0, "x");
226    /// let mut expr2 = expr.tanh("tanh(x)");
227    /// expr2.learn(1e-09);
228    /// expr2.recalculate();
229    /// ```
230    pub fn recalculate(&mut self) {
231        match self.expr_type() {
232            ExprType::Leaf => {}
233            ExprType::Unary => {
234                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
235                operand1.recalculate();
236
237                self.result = match self.operation {
238                    Operation::Tanh => operand1.result.tanh(),
239                    Operation::Exp => operand1.result.exp(),
240                    Operation::ReLU => operand1.result.max(0.0),
241                    Operation::Log => operand1.result.ln(),
242                    Operation::Neg => -operand1.result,
243                    _ => panic!("Invalid unary operation {:?}", self.operation),
244                };
245            }
246            ExprType::Binary => {
247                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
248                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
249
250                operand1.recalculate();
251                operand2.recalculate();
252
253                self.result = match self.operation {
254                    Operation::Add => operand1.result + operand2.result,
255                    Operation::Sub => operand1.result - operand2.result,
256                    Operation::Mul => operand1.result * operand2.result,
257                    Operation::Div => operand1.result / operand2.result,
258                    Operation::Pow => operand1.result.powf(operand2.result),
259                    _ => panic!("Invalid binary operation: {:?}", self.operation),
260                };
261            }
262        }
263    }
264
265    /// Applies backpropagation to the expression, updating the values of the
266    /// gradients and the expression itself.
267    /// 
268    /// This method will change the gradients based on the gradient of the last
269    /// expression in the calculation graph.
270    /// 
271    /// Example:
272    /// 
273    /// ```rust
274    /// use alpha_micrograd_rust::value::Expr;
275    /// 
276    /// let expr = Expr::new_leaf(1.0, "x");
277    /// let mut expr2 = expr.tanh("tanh(x)");
278    /// expr2.learn(1e-09);
279    /// ```
280    /// 
281    /// After adjusting the gradients, the method will update the values of the
282    /// individual expression tree nodes to minimize the loss function.
283    /// 
284    /// In order to get a new calculation of the expression tree, you'll need to call
285    /// [`Expr::recalculate`] after calling [`Expr::learn`].
286    pub fn learn(&mut self, learning_rate: f64) {
287        self.grad = 1.0;
288        self.learn_internal(learning_rate);
289    }
290
291    fn learn_internal(&mut self, learning_rate: f64) {
292        match self.expr_type() {
293            ExprType::Leaf => {
294                // leaves have their gradient set externally by other nodes in the tree
295                // leaves can be learnable, in which case we update the value
296                if self.is_learnable {
297                    self.result -= learning_rate * self.grad;
298                 }
299            }
300            ExprType::Unary => {
301                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
302
303                match self.operation {
304                    Operation::Tanh => {
305                        let tanh_grad = 1.0 - (self.result * self.result);
306                        operand1.grad = self.grad * tanh_grad;
307                    }
308                    Operation::Exp => {
309                        operand1.grad = self.grad * self.result;
310                    }
311                    Operation::ReLU => {
312                        operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
313                    }
314                    Operation::Log => {
315                        operand1.grad = self.grad / operand1.result;
316                    }
317                    Operation::Neg => {
318                        operand1.grad = -self.grad;
319                    }
320                    _ => panic!("Invalid unary operation {:?}", self.operation),
321                }
322
323                operand1.learn_internal(learning_rate);
324            }
325            ExprType::Binary => {
326                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
327                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
328
329                match self.operation {
330                    Operation::Add => {
331                        operand1.grad = self.grad;
332                        operand2.grad = self.grad;
333                    }
334                    Operation::Sub => {
335                        operand1.grad = self.grad;
336                        operand2.grad = -self.grad;
337                    }
338                    Operation::Mul => {
339                        let operand2_result = operand2.result;
340                        let operand1_result = operand1.result;
341
342                        operand1.grad = self.grad * operand2_result;
343                        operand2.grad = self.grad * operand1_result;
344                    }
345                    Operation::Div => {
346                        let operand2_result = operand2.result;
347                        let operand1_result = operand1.result;
348
349                        operand1.grad = self.grad / operand2_result;
350                        operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
351                    }
352                    Operation::Pow => {
353                        let exponent = operand2.result;
354                        let base = operand1.result;
355
356                        operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
357                        operand2.grad = self.grad * base.powf(exponent) * base.ln();
358                    }
359                    _ => panic!("Invalid binary operation: {:?}", self.operation),
360                }
361
362                operand1.learn_internal(learning_rate);
363                operand2.learn_internal(learning_rate);
364            }
365        }
366    }
367
368    /// Finds a node in the expression tree by its name.
369    /// 
370    /// This method will search the expression tree for a node with the given name.
371    /// If the node is not found, it will return [None].
372    /// 
373    /// Example:
374    /// ```rust
375    /// use alpha_micrograd_rust::value::Expr;
376    /// 
377    /// let expr = Expr::new_leaf(1.0, "x");
378    /// let expr2 = expr.tanh("tanh(x)");
379    /// let original = expr2.find("x");
380    /// 
381    /// assert_eq!(original.expect("Could not find x").result, 1.0);
382    /// ```
383    pub fn find(&self, name: &str) -> Option<&Expr> {
384        if self.name == name {
385            return Some(self);
386        }
387
388        match self.expr_type() {
389            ExprType::Leaf => None,
390            ExprType::Unary => {
391                let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
392                operand1.find(name)
393            }
394            ExprType::Binary => {
395                let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
396                let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
397
398                let result = operand1.find(name);
399                if result.is_some() {
400                    return result;
401                }
402
403                operand2.find(name)
404            }
405        }
406    }
407}
408
409/// Implements the [`Add`] trait for the [`Expr`] struct.
410/// 
411/// This implementation allows the addition of two [`Expr`] objects.
412/// 
413/// Example:
414/// ```rust
415/// use alpha_micrograd_rust::value::Expr;
416/// 
417/// let expr = Expr::new_leaf(1.0, "x");
418/// let expr2 = Expr::new_leaf(2.0, "y");
419/// 
420/// let result = expr + expr2;
421/// 
422/// println!("Result: {}", result.result); // 3.0
423/// ```
424impl Add for Expr {
425    type Output = Expr;
426
427    fn add(self, other: Expr) -> Expr {
428        let result = self.result + other.result;
429        let name = &format!("({} + {})", self.name, other.name);
430        Expr::new_binary(self, other, Operation::Add, result, name)
431    }
432}
433
434/// Implements the [`Add`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
435/// 
436/// This implementation allows the addition of an [`Expr`] object and a [`f64`] value.
437/// 
438/// Example:
439/// ```rust
440/// use alpha_micrograd_rust::value::Expr;
441/// 
442/// let expr = Expr::new_leaf(1.0, "x");
443/// let result = expr + 2.0;
444/// 
445/// println!("Result: {}", result.result); // 3.0
446/// ```
447impl Add<f64> for Expr {
448    type Output = Expr;
449
450    fn add(self, other: f64) -> Expr {
451        let operand2 = Expr::new_leaf(other, &other.to_string());
452        self + operand2
453    }
454}
455
456/// Implements the [`Add`] trait for the [`f64`] type, when the right operand is an [`Expr`].
457/// 
458/// This implementation allows the addition of a [`f64`] value and an [`Expr`] object.
459/// 
460/// Example:
461/// ```rust
462/// use alpha_micrograd_rust::value::Expr;
463/// 
464/// let expr = Expr::new_leaf(1.0, "x");
465/// let result = 2.0 + expr;
466/// 
467/// println!("Result: {}", result.result); // 3.0
468/// ```
469impl Add<Expr> for f64 {
470    type Output = Expr;
471
472    fn add(self, other: Expr) -> Expr {
473        let operand1 = Expr::new_leaf(self, &self.to_string());
474        operand1 + other
475    }
476}
477
478/// Implements the [`Mul`] trait for the [`Expr`] struct.
479/// 
480/// This implementation allows the multiplication of two [`Expr`] objects.
481/// 
482/// Example:
483/// 
484/// ```rust
485/// use alpha_micrograd_rust::value::Expr;
486/// 
487/// let expr = Expr::new_leaf(1.0, "x");
488/// let expr2 = Expr::new_leaf(2.0, "y");
489/// 
490/// let result = expr * expr2;
491/// 
492/// println!("Result: {}", result.result); // 2.0
493/// ```
494impl Mul for Expr {
495    type Output = Expr;
496
497    fn mul(self, other: Expr) -> Expr {
498        let result = self.result * other.result;
499        let name = &format!("({} * {})", self.name, other.name);
500        Expr::new_binary(self, other, Operation::Mul, result, name)
501    }
502}
503
504/// Implements the [`Mul`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
505/// 
506/// This implementation allows the multiplication of an [`Expr`] object and a [`f64`] value.
507/// 
508/// Example:
509/// 
510/// ```rust
511/// use alpha_micrograd_rust::value::Expr;
512/// 
513/// let expr = Expr::new_leaf(1.0, "x");
514/// let result = expr * 2.0;
515/// 
516/// println!("Result: {}", result.result); // 2.0
517/// ```
518impl Mul<f64> for Expr {
519    type Output = Expr;
520
521    fn mul(self, other: f64) -> Expr {
522        let operand2 = Expr::new_leaf(other, &other.to_string());
523        self * operand2
524    }
525}
526
527/// Implements the [`Mul`] trait for the [`f64`] type, when the right operand is an [`Expr`].
528/// 
529/// This implementation allows the multiplication of a [`f64`] value and an [`Expr`] object.
530/// 
531/// Example:
532/// 
533/// ```rust
534/// use alpha_micrograd_rust::value::Expr;
535/// 
536/// let expr = Expr::new_leaf(1.0, "x");
537/// let result = 2.0 * expr;
538/// 
539/// println!("Result: {}", result.result); // 2.0
540/// ```
541impl Mul<Expr> for f64 {
542    type Output = Expr;
543
544    fn mul(self, other: Expr) -> Expr {
545        let operand1 = Expr::new_leaf(self, &self.to_string());
546        operand1 * other
547    }
548}
549
550/// Implements the [`Sub`] trait for the [`Expr`] struct.
551/// 
552/// This implementation allows the subtraction of two [`Expr`] objects.
553/// 
554/// Example:
555/// 
556/// ```rust
557/// use alpha_micrograd_rust::value::Expr;
558/// 
559/// let expr = Expr::new_leaf(1.0, "x");
560/// let expr2 = Expr::new_leaf(2.0, "y");
561/// 
562/// let result = expr - expr2;
563/// 
564/// println!("Result: {}", result.result); // -1.0
565/// ```
566impl Sub for Expr {
567    type Output = Expr;
568
569    fn sub(self, other: Expr) -> Expr {
570        let result = self.result - other.result;
571        let name = &format!("({} - {})", self.name, other.name);
572        Expr::new_binary(self, other, Operation::Sub, result, name)
573    }
574}
575
576/// Implements the [`Sub`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
577/// 
578/// This implementation allows the subtraction of an [`Expr`] object and a [`f64`] value.
579/// 
580/// Example:
581/// 
582/// ```rust
583/// use alpha_micrograd_rust::value::Expr;
584/// 
585/// let expr = Expr::new_leaf(1.0, "x");
586/// let result = expr - 2.0;
587/// 
588/// println!("Result: {}", result.result); // -1.0
589/// ```
590impl Sub<f64> for Expr {
591    type Output = Expr;
592
593    fn sub(self, other: f64) -> Expr {
594        let operand2 = Expr::new_leaf(other, &other.to_string());
595        self - operand2
596    }
597}
598
599/// Implements the [`Sub`] trait for the [`f64`] type, when the right operand is an [`Expr`].
600/// 
601/// This implementation allows the subtraction of a [`f64`] value and an [`Expr`] object.
602/// 
603/// Example:
604/// 
605/// ```rust
606/// use alpha_micrograd_rust::value::Expr;
607/// 
608/// let expr = Expr::new_leaf(1.0, "x");
609/// let result = 2.0 - expr;
610/// 
611/// println!("Result: {}", result.result); // 1.0
612/// ```
613impl Sub<Expr> for f64 {
614    type Output = Expr;
615
616    fn sub(self, other: Expr) -> Expr {
617        let operand1 = Expr::new_leaf(self, &self.to_string());
618        operand1 - other
619    }
620}
621
622/// Implements the [`Div`] trait for the [`Expr`] struct.
623/// 
624/// This implementation allows the division of two [`Expr`] objects.
625/// 
626/// Example:
627/// 
628/// ```rust
629/// use alpha_micrograd_rust::value::Expr;
630/// 
631/// let expr = Expr::new_leaf(1.0, "x");
632/// let expr2 = Expr::new_leaf(2.0, "y");
633/// 
634/// let result = expr / expr2;
635/// 
636/// println!("Result: {}", result.result); // 0.5
637/// ```
638impl Div for Expr {
639    type Output = Expr;
640
641    fn div(self, other: Expr) -> Expr {
642        let result = self.result / other.result;
643        let name = &format!("({} / {})", self.name, other.name);
644        Expr::new_binary(self, other, Operation::Div, result, name)
645    }
646}
647
648/// Implements the [`Div`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
649/// 
650/// This implementation allows the division of an [`Expr`] object and a [`f64`] value.
651/// 
652/// Example:
653/// 
654/// ```rust
655/// use alpha_micrograd_rust::value::Expr;
656/// 
657/// let expr = Expr::new_leaf(1.0, "x");
658/// let result = expr / 2.0;
659/// 
660/// println!("Result: {}", result.result); // 0.5
661/// ```
662impl Div<f64> for Expr {
663    type Output = Expr;
664
665    fn div(self, other: f64) -> Expr {
666        let operand2 = Expr::new_leaf(other, &other.to_string());
667        self / operand2
668    }
669}
670
671/// Implements the [`Sum`] trait for the [`Expr`] struct.
672/// 
673/// Note that this implementation will generate temporary [`Expr`] objects,
674/// which may not be the most efficient way to sum a collection of [`Expr`] objects.
675/// However, it is provided as a convenience method for users that want to use sum
676/// over an [`Iterator<Expr>`].
677/// 
678/// Example:
679/// 
680/// ```rust
681/// use alpha_micrograd_rust::value::Expr;
682/// 
683/// let expr = Expr::new_leaf(1.0, "x");
684/// let expr2 = Expr::new_leaf(2.0, "y");
685/// let expr3 = Expr::new_leaf(3.0, "z");
686/// 
687/// let sum = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
688/// 
689/// println!("Result: {}", sum.result); // 6.0
690/// ```
691impl Sum for Expr {
692    fn sum<I>(iter: I) -> Self
693    where
694        I: Iterator<Item = Self>,
695    {
696        iter.reduce(|acc, x| acc + x)
697            .unwrap_or(Expr::new_leaf(0.0, "0.0"))
698    }
699}
700
701#[cfg(test)]
702mod tests {
703    use super::*;
704
705    fn assert_float_eq(f1: f64, f2: f64) {
706        let places = 7;
707        let tolerance = 10.0_f64.powi(-places);
708        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
709    }
710
711    #[test]
712    fn test() {
713        let expr = Expr::new_leaf(1.0, "x");
714        assert_eq!(expr.result, 1.0);
715    }
716
717    #[test]
718    fn test_unary() {
719        let expr = Expr::new_leaf(1.0, "x");
720        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");
721
722        assert_eq!(expr2.result, 1.1);
723        assert_eq!(expr2.operand1.unwrap().result, 1.0);
724    }
725
726    #[test]
727    #[should_panic]
728    fn test_unary_expression_type_check() {
729        let expr = Expr::new_leaf(1.0, "x");
730        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
731    }
732
733    #[test]
734    fn test_binary() {
735        let expr = Expr::new_leaf(1.0, "x");
736        let expr2 = Expr::new_leaf(2.0, "y");
737        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");
738
739        assert_eq!(expr3.result, 1.1);
740        assert_eq!(expr3.operand1.unwrap().result, 1.0);
741        assert_eq!(expr3.operand2.unwrap().result, 2.0);
742    }
743
744    #[test]
745    #[should_panic]
746    fn test_binary_expression_type_check() {
747        let expr = Expr::new_leaf(1.0, "x");
748        let expr2 = Expr::new_leaf(2.0, "y");
749        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
750    }
751
752    #[test]
753    fn test_mixed_tree() {
754        let expr = Expr::new_leaf(1.0, "x");
755        let expr2 = Expr::new_leaf(2.0, "y");
756        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
757        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");
758
759        assert_eq!(expr4.result, 1.2);
760        let expr3 = expr4.operand1.unwrap();
761        assert_eq!(expr3.result, 1.1);
762        assert_eq!(expr3.operand1.unwrap().result, 1.0);
763        assert_eq!(expr3.operand2.unwrap().result, 2.0);
764    }
765
766    #[test]
767    fn test_tanh() {
768        let expr = Expr::new_leaf(1.0, "x");
769        let expr2 = expr.tanh("tanh(x)");
770
771        assert_eq!(expr2.result, 0.7615941559557649);
772        assert!(expr2.operand1.is_some());
773        assert_eq!(expr2.operand1.unwrap().result, 1.0);
774        assert_eq!(expr2.operation, Operation::Tanh);
775        assert!(expr2.operand2.is_none());
776
777        // Some other known values
778        fn get_tanh(x: f64) -> f64 {
779            Expr::new_leaf(x, "x").tanh("tanh(x)").result
780        }
781
782        assert_float_eq(get_tanh(10.74), 0.9999999);
783        assert_float_eq(get_tanh(-10.74), -0.9999999);
784        assert_float_eq(get_tanh(0.0), 0.0);
785    }
786
787    #[test]
788    fn test_exp() {
789        let expr = Expr::new_leaf(1.0, "x");
790        let expr2 = expr.exp("exp(x)");
791
792        assert_eq!(expr2.result, 2.718281828459045);
793        assert!(expr2.operand1.is_some());
794        assert_eq!(expr2.operand1.unwrap().result, 1.0);
795        assert_eq!(expr2.operation, Operation::Exp);
796        assert!(expr2.operand2.is_none());
797    }
798
799    #[test]
800    fn test_relu() {
801        // negative case
802        let expr = Expr::new_leaf(-1.0, "x");
803        let expr2 = expr.relu("relu(x)");
804
805        assert_eq!(expr2.result, 0.0);
806        assert!(expr2.operand1.is_some());
807        assert_eq!(expr2.operand1.unwrap().result, -1.0);
808        assert_eq!(expr2.operation, Operation::ReLU);
809        assert!(expr2.operand2.is_none());
810
811        // positive case
812        let expr = Expr::new_leaf(1.0, "x");
813        let expr2 = expr.relu("relu(x)");
814
815        assert_eq!(expr2.result, 1.0);
816        assert!(expr2.operand1.is_some());
817        assert_eq!(expr2.operand1.unwrap().result, 1.0);
818        assert_eq!(expr2.operation, Operation::ReLU);
819        assert!(expr2.operand2.is_none());
820    }
821
822    #[test]
823    fn test_pow() {
824        let expr = Expr::new_leaf(2.0, "x");
825        let expr2 = Expr::new_leaf(3.0, "y");
826        let result = expr.pow(expr2, "x^y");
827
828        assert_eq!(result.result, 8.0);
829        assert!(result.operand1.is_some());
830        assert_eq!(result.operand1.unwrap().result, 2.0);
831        assert_eq!(result.operation, Operation::Pow);
832        
833        assert!(result.operand2.is_some());
834        assert_eq!(result.operand2.unwrap().result, 3.0);
835    }
836
837    #[test]
838    fn test_add() {
839        let expr = Expr::new_leaf(1.0, "x");
840        let expr2 = Expr::new_leaf(2.0, "y");
841        let expr3 = expr + expr2;
842
843        assert_eq!(expr3.result, 3.0);
844        assert!(expr3.operand1.is_some());
845        assert_eq!(expr3.operand1.unwrap().result, 1.0);
846        assert!(expr3.operand2.is_some());
847        assert_eq!(expr3.operand2.unwrap().result, 2.0);
848        assert_eq!(expr3.operation, Operation::Add);
849        assert_eq!(expr3.name, "(x + y)");
850    }
851
852    #[test]
853    fn test_add_f64() {
854        let expr = Expr::new_leaf(1.0, "x");
855        let expr2 = expr + 2.0;
856
857        assert_eq!(expr2.result, 3.0);
858        assert!(expr2.operand1.is_some());
859        assert_eq!(expr2.operand1.unwrap().result, 1.0);
860        assert!(expr2.operand2.is_some());
861        assert_eq!(expr2.operand2.unwrap().result, 2.0);
862        assert_eq!(expr2.operation, Operation::Add);
863        assert_eq!(expr2.name, "(x + 2)");
864    }
865
866    #[test]
867    fn test_add_f64_expr() {
868        let expr = Expr::new_leaf(1.0, "x");
869        let expr2 = 2.0 + expr;
870
871        assert_eq!(expr2.result, 3.0);
872        assert!(expr2.operand1.is_some());
873        assert_eq!(expr2.operand1.unwrap().result, 2.0);
874        assert!(expr2.operand2.is_some());
875        assert_eq!(expr2.operand2.unwrap().result, 1.0);
876        assert_eq!(expr2.operation, Operation::Add);
877        assert_eq!(expr2.name, "(2 + x)");
878    }
879
880    #[test]
881    fn test_mul() {
882        let expr = Expr::new_leaf(2.0, "x");
883        let expr2 = Expr::new_leaf(3.0, "y");
884        let expr3 = expr * expr2;
885
886        assert_eq!(expr3.result, 6.0);
887        assert!(expr3.operand1.is_some());
888        assert_eq!(expr3.operand1.unwrap().result, 2.0);
889        assert!(expr3.operand2.is_some());
890        assert_eq!(expr3.operand2.unwrap().result, 3.0);
891        assert_eq!(expr3.operation, Operation::Mul);
892        assert_eq!(expr3.name, "(x * y)");
893    }
894
895    #[test]
896    fn test_mul_f64() {
897        let expr = Expr::new_leaf(2.0, "x");
898        let expr2 = expr * 3.0;
899
900        assert_eq!(expr2.result, 6.0);
901        assert!(expr2.operand1.is_some());
902        assert_eq!(expr2.operand1.unwrap().result, 2.0);
903        assert!(expr2.operand2.is_some());
904        assert_eq!(expr2.operand2.unwrap().result, 3.0);
905        assert_eq!(expr2.operation, Operation::Mul);
906        assert_eq!(expr2.name, "(x * 3)");
907    }
908
909    #[test]
910    fn test_mul_f64_expr() {
911        let expr = Expr::new_leaf(2.0, "x");
912        let expr2 = 3.0 * expr;
913
914        assert_eq!(expr2.result, 6.0);
915        assert!(expr2.operand1.is_some());
916        assert_eq!(expr2.operand1.unwrap().result, 3.0);
917        assert!(expr2.operand2.is_some());
918        assert_eq!(expr2.operand2.unwrap().result, 2.0);
919        assert_eq!(expr2.operation, Operation::Mul);
920        assert_eq!(expr2.name, "(3 * x)");
921    }
922
923    #[test]
924    fn test_sub() {
925        let expr = Expr::new_leaf(2.0, "x");
926        let expr2 = Expr::new_leaf(3.0, "y");
927        let expr3 = expr - expr2;
928
929        assert_eq!(expr3.result, -1.0);
930        assert!(expr3.operand1.is_some());
931        assert_eq!(expr3.operand1.unwrap().result, 2.0);
932        assert!(expr3.operand2.is_some());
933        assert_eq!(expr3.operand2.unwrap().result, 3.0);
934        assert_eq!(expr3.operation, Operation::Sub);
935        assert_eq!(expr3.name, "(x - y)");
936    }
937
938    #[test]
939    fn test_sub_f64() {
940        let expr = Expr::new_leaf(2.0, "x");
941        let expr2 = expr - 3.0;
942
943        assert_eq!(expr2.result, -1.0);
944        assert!(expr2.operand1.is_some());
945        assert_eq!(expr2.operand1.unwrap().result, 2.0);
946        assert!(expr2.operand2.is_some());
947        assert_eq!(expr2.operand2.unwrap().result, 3.0);
948        assert_eq!(expr2.operation, Operation::Sub);
949        assert_eq!(expr2.name, "(x - 3)");
950    }
951
952    #[test]
953    fn test_sub_f64_expr() {
954        let expr = Expr::new_leaf(2.0, "x");
955        let expr2 = 3.0 - expr;
956
957        assert_eq!(expr2.result, 1.0);
958        assert!(expr2.operand1.is_some());
959        assert_eq!(expr2.operand1.unwrap().result, 3.0);
960        assert!(expr2.operand2.is_some());
961        assert_eq!(expr2.operand2.unwrap().result, 2.0);
962        assert_eq!(expr2.operation, Operation::Sub);
963        assert_eq!(expr2.name, "(3 - x)");
964    }
965
966    #[test]
967    fn test_div() {
968        let expr = Expr::new_leaf(6.0, "x");
969        let expr2 = Expr::new_leaf(3.0, "y");
970        let expr3 = expr / expr2;
971
972        assert_eq!(expr3.result, 2.0);
973        assert!(expr3.operand1.is_some());
974        assert_eq!(expr3.operand1.unwrap().result, 6.0);
975        assert!(expr3.operand2.is_some());
976        assert_eq!(expr3.operand2.unwrap().result, 3.0);
977        assert_eq!(expr3.operation, Operation::Div);
978        assert_eq!(expr3.name, "(x / y)");
979    }
980
981    #[test]
982    fn test_div_f64() {
983        let expr = Expr::new_leaf(6.0, "x");
984        let expr2 = expr / 3.0;
985
986        assert_eq!(expr2.result, 2.0);
987        assert!(expr2.operand1.is_some());
988        assert_eq!(expr2.operand1.unwrap().result, 6.0);
989        assert!(expr2.operand2.is_some());
990        assert_eq!(expr2.operand2.unwrap().result, 3.0);
991        assert_eq!(expr2.operation, Operation::Div);
992        assert_eq!(expr2.name, "(x / 3)");
993    }
994
995    #[test]
996    fn test_log() {
997        let expr = Expr::new_leaf(2.0, "x");
998        let expr2 = expr.log("log(x)");
999
1000        assert_eq!(expr2.result, 0.6931471805599453);
1001        assert!(expr2.operand1.is_some());
1002        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1003        assert_eq!(expr2.operation, Operation::Log);
1004        assert!(expr2.operand2.is_none());
1005    }
1006
1007    #[test]
1008    fn test_neg() {
1009        let expr = Expr::new_leaf(2.0, "x");
1010        let expr2 = expr.neg("neg(x)");
1011
1012        assert_eq!(expr2.result, -2.0);
1013        assert!(expr2.operand1.is_some());
1014        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1015        assert_eq!(expr2.operation, Operation::Neg);
1016        assert!(expr2.operand2.is_none());
1017    }
1018
1019    #[test]
1020    fn test_backpropagation_add() {
1021        let operand1 = Expr::new_leaf(1.0, "x");
1022        let operand2 = Expr::new_leaf(2.0, "y");
1023        let mut expr3 = operand1 + operand2;
1024
1025        expr3.learn(1e-09);
1026
1027        let operand1 = expr3.operand1.unwrap();
1028        let operand2 = expr3.operand2.unwrap();
1029        assert_eq!(operand1.grad, 1.0);
1030        assert_eq!(operand2.grad, 1.0);
1031    }
1032
1033    #[test]
1034    fn test_backpropagation_sub() {
1035        let operand1 = Expr::new_leaf(1.0, "x");
1036        let operand2 = Expr::new_leaf(2.0, "y");
1037        let mut expr3 = operand1 - operand2;
1038
1039        expr3.learn(1e-09);
1040
1041        let operand1 = expr3.operand1.unwrap();
1042        let operand2 = expr3.operand2.unwrap();
1043        assert_eq!(operand1.grad, 1.0);
1044        assert_eq!(operand2.grad, -1.0);
1045    }
1046
1047    #[test]
1048    fn test_backpropagation_mul() {
1049        let operand1 = Expr::new_leaf(3.0, "x");
1050        let operand2 = Expr::new_leaf(4.0, "y");
1051        let mut expr3 = operand1 * operand2;
1052
1053        expr3.learn(1e-09);
1054
1055        let operand1 = expr3.operand1.unwrap();
1056        let operand2 = expr3.operand2.unwrap();
1057        assert_eq!(operand1.grad, 4.0);
1058        assert_eq!(operand2.grad, 3.0);
1059    }
1060
1061    #[test]
1062    fn test_backpropagation_div() {
1063        let operand1 = Expr::new_leaf(3.0, "x");
1064        let operand2 = Expr::new_leaf(4.0, "y");
1065        let mut expr3 = operand1 / operand2;
1066
1067        expr3.learn(1e-09);
1068
1069        let operand1 = expr3.operand1.unwrap();
1070        let operand2 = expr3.operand2.unwrap();
1071        assert_eq!(operand1.grad, 0.25);
1072        assert_eq!(operand2.grad, -0.1875);
1073    }
1074
1075    #[test]
1076    fn test_backpropagation_tanh() {
1077        let operand1 = Expr::new_leaf(0.0, "x");
1078        let mut expr2 = operand1.tanh("tanh(x)");
1079
1080        expr2.learn(1e-09);
1081
1082        let operand1 = expr2.operand1.unwrap();
1083        assert_eq!(operand1.grad, 1.0);
1084    }
1085
1086    #[test]
1087    fn test_backpropagation_relu() {
1088        let operand1 = Expr::new_leaf(-1.0, "x");
1089        let mut expr2 = operand1.relu("relu(x)");
1090
1091        expr2.learn(1e-09);
1092
1093        let operand1 = expr2.operand1.unwrap();
1094        assert_eq!(operand1.grad, 0.0);
1095    }
1096
1097    #[test]
1098    fn test_backpropagation_exp() {
1099        let operand1 = Expr::new_leaf(0.0, "x");
1100        let mut expr2 = operand1.exp("exp(x)");
1101
1102        expr2.learn(1e-09);
1103
1104        let operand1 = expr2.operand1.unwrap();
1105        assert_eq!(operand1.grad, 1.0);
1106    }
1107
1108    #[test]
1109    fn test_backpropagation_pow() {
1110        let operand1 = Expr::new_leaf(2.0, "x");
1111        let operand2 = Expr::new_leaf(3.0, "y");
1112        let mut expr3 = operand1.pow(operand2, "x^y");
1113
1114        expr3.learn(1e-09);
1115
1116        let operand1 = expr3.operand1.unwrap();
1117        let operand2 = expr3.operand2.unwrap();
1118        assert_eq!(operand1.grad, 12.0);
1119        assert_eq!(operand2.grad, 5.545177444479562);
1120    }
1121
1122    #[test]
1123    fn test_backpropagation_mixed_tree() {
1124        let operand1 = Expr::new_leaf(1.0, "x");
1125        let operand2 = Expr::new_leaf(2.0, "y");
1126        let expr3 = operand1 + operand2;
1127        let mut expr4 = expr3.tanh("tanh(x + y)");
1128
1129        expr4.learn(1e-09);
1130
1131        let expr3 = expr4.operand1.unwrap();
1132        let operand1 = expr3.operand1.unwrap();
1133        let operand2 = expr3.operand2.unwrap();
1134
1135        assert_eq!(expr3.grad, 0.009866037165440211);
1136        assert_eq!(operand1.grad, 0.009866037165440211);
1137        assert_eq!(operand2.grad, 0.009866037165440211);
1138    }
1139
1140    #[test]
1141    fn test_backpropagation_karpathys_example() {
1142        let x1 = Expr::new_leaf(2.0, "x1");
1143        let x2 = Expr::new_leaf(0.0, "x2");
1144        let w1 = Expr::new_leaf(-3.0, "w1");
1145        let w2 = Expr::new_leaf(1.0, "w2");
1146        let b = Expr::new_leaf(6.8813735870195432, "b");
1147
1148        let x1w1 = x1 * w1;
1149        let x2w2 = x2 * w2;
1150        let x1w1_x2w2 = x1w1 + x2w2;
1151        let n = x1w1_x2w2 + b;
1152        let mut o = n.tanh("tanh(n)");
1153
1154        o.learn(1e-09);
1155
1156        assert_eq!(o.operation, Operation::Tanh);
1157        assert_eq!(o.grad, 1.0);
1158
1159        let n = o.operand1.unwrap();
1160        assert_eq!(n.operation, Operation::Add);
1161        assert_float_eq(n.grad, 0.5);
1162
1163        let x1w1_x2w2 = n.operand1.unwrap();
1164        assert_eq!(x1w1_x2w2.operation, Operation::Add);
1165        assert_float_eq(x1w1_x2w2.grad, 0.5);
1166
1167        let b = n.operand2.unwrap();
1168        assert_eq!(b.operation, Operation::None);
1169        assert_float_eq(b.grad, 0.5);
1170
1171        let x1w1 = x1w1_x2w2.operand1.unwrap();
1172        assert_eq!(x1w1.operation, Operation::Mul);
1173        assert_float_eq(x1w1.grad, 0.5);
1174
1175        let x2w2 = x1w1_x2w2.operand2.unwrap();
1176        assert_eq!(x2w2.operation, Operation::Mul);
1177        assert_float_eq(x2w2.grad, 0.5);
1178
1179        let x1 = x1w1.operand1.unwrap();
1180        assert_eq!(x1.operation, Operation::None);
1181        assert_float_eq(x1.grad, -1.5);
1182
1183        let w1 = x1w1.operand2.unwrap();
1184        assert_eq!(w1.operation, Operation::None);
1185        assert_float_eq(w1.grad, 1.0);
1186
1187        let x2 = x2w2.operand1.unwrap();
1188        assert_eq!(x2.operation, Operation::None);
1189        assert_float_eq(x2.grad, 0.5);
1190
1191        let w2 = x2w2.operand2.unwrap();
1192        assert_eq!(w2.operation, Operation::None);
1193        assert_float_eq(w2.grad, 0.0);
1194    }
1195
1196    #[test]
1197    fn test_learn_simple() {
1198        let mut expr = Expr::new_leaf(1.0, "x");
1199        expr.learn(1e-01);
1200
1201        assert_float_eq(expr.result, 0.9);
1202    }
1203
1204    #[test]
1205    fn test_learn_skips_non_learnable() {
1206        let mut expr = Expr::new_leaf(1.0, "x");
1207        expr.is_learnable = false;
1208        expr.learn(1e-01);
1209
1210        assert_float_eq(expr.result, 1.0);
1211    }
1212
1213    #[test]
1214    fn test_find_simple() {
1215        let expr = Expr::new_leaf(1.0, "x");
1216        let expr2 = expr.tanh("tanh(x)");
1217
1218        let found = expr2.find("x");
1219        assert!(found.is_some());
1220        assert_eq!(found.unwrap().name, "x");
1221    }
1222
1223    #[test]
1224    fn test_find_not_found() {
1225        let expr = Expr::new_leaf(1.0, "x");
1226        let expr2 = expr.tanh("tanh(x)");
1227
1228        let found = expr2.find("y");
1229        assert!(found.is_none());
1230    }
1231
1232    #[test]
1233    fn test_sum_iterator() {
1234        let expr = Expr::new_leaf(1.0, "x");
1235        let expr2 = Expr::new_leaf(2.0, "y");
1236        let expr3 = Expr::new_leaf(3.0, "z");
1237
1238        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
1239        assert_eq!(sum.result, 6.0);
1240    }
1241
1242    #[test]
1243    fn test_find_after_clone() {
1244        let expr = Expr::new_leaf(1.0, "x");
1245        let expr2 = expr.tanh("tanh(x)");
1246        let expr2_clone = expr2.clone();
1247
1248        let found = expr2_clone.find("x");
1249        assert!(found.is_some());
1250        assert_eq!(found.unwrap().name, "x");
1251    }
1252}