alpha_micrograd_rust/
value.rs

//! A simple library for creating and backpropagating through expression trees.
//!
//! This module provides the following elements to construct expression trees:
//! - [`Expr`]: a node in the expression tree
5#![deny(missing_docs)]
6use std::ops::{Add, Div, Mul, Sub};
7use std::iter::Sum;
8
/// The operation that produced an [`Expr`] node.
///
/// `None` marks a leaf; every other variant is either a unary operation
/// (one operand) or a binary operation (two operands).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// No operation: the node is a leaf that just holds a value.
    None,
    /// Binary addition: `a + b`.
    Add,
    /// Binary subtraction: `a - b`.
    Sub,
    /// Binary multiplication: `a * b`.
    Mul,
    /// Binary division: `a / b`.
    Div,
    /// Unary hyperbolic tangent: `tanh(a)`.
    Tanh,
    /// Unary exponential: `e^a`.
    Exp,
    /// Binary power: `a^b`.
    Pow,
    /// Unary rectified linear unit: `max(a, 0)`.
    ReLU,
    /// Unary natural logarithm: `ln(a)`.
    Log,
    /// Unary negation: `-a`.
    Neg,
}
23
24impl Operation {
25    fn assert_is_type(&self, expr_type: ExprType) {
26        match self {
27            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
28            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
29            _ => assert_eq!(expr_type, ExprType::Binary),
30        }
31    }
32}
33
/// Arity class of an [`Expr`] node, derived from its operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExprType {
    /// A leaf: holds a value and has no operands.
    Leaf,
    /// A node with exactly one operand.
    Unary,
    /// A node with two operands.
    Binary,
}
40
/// Expression representing a node in a calculation graph.
///
/// This struct represents a node in a calculation graph. It can be a leaf node, a unary operation or a binary operation.
///
/// A leaf node holds a value, which is the one that is used in the calculation.
///
/// A unary expression is the result of applying a unary operation to another expression. For example, the result of applying the `tanh` operation to a leaf node.
///
/// A binary expression is the result of applying a binary operation to two other expressions. For example, the result of adding two leaf nodes.
///
/// Every node owns its operands (they are boxed inside it), so the graph is a
/// tree: reusing a sub-expression in two places requires cloning it, and the
/// clone is an independent copy.
#[derive(Debug, Clone)]
pub struct Expr {
    // First operand: set for unary and binary expressions, `None` for leaves.
    operand1: Option<Box<Expr>>,
    // Second operand: set only for binary expressions.
    operand2: Option<Box<Expr>>,
    // Operation that produced this node (`Operation::None` for leaves).
    operation: Operation,
    /// The numeric result of the expression, as result of applying the operation to the operands.
    pub result: f64,
    /// Whether the expression is learnable or not. Only learnable [`Expr`] will have their values updated during backpropagation (learning).
    pub is_learnable: bool,
    // Gradient of the root expression with respect to this node; written by
    // the parent node during backpropagation (`learn`).
    grad: f64,
    /// The name of the expression, used to identify it in the calculation graph.
    pub name: String,
}
63
64impl Expr {
65    /// Creates a new leaf expression with the given value.
66    /// 
67    /// Example:
68    /// ```rust
69    /// use alpha_micrograd_rust::value::Expr;
70    /// 
71    /// let expr = Expr::new_leaf(1.0, "x");
72    /// ```
73    pub fn new_leaf(value: f64, name: &str) -> Expr {
74        Expr {
75            operand1: None,
76            operand2: None,
77            operation: Operation::None,
78            result: value,
79            is_learnable: true,
80            grad: 0.0,
81            name: name.to_string(),
82        }
83    }
84
85    fn expr_type(&self) -> ExprType {
86        match self.operation {
87            Operation::None => ExprType::Leaf,
88            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
89            _ => ExprType::Binary,
90        }
91    }
92
93    fn new_unary(operand: Expr, operation: Operation, result: f64, name: &str) -> Expr {
94        operation.assert_is_type(ExprType::Unary);
95        Expr {
96            operand1: Some(Box::new(operand)),
97            operand2: None,
98            operation,
99            result,
100            is_learnable: false,
101            grad: 0.0,
102            name: name.to_string(),
103        }
104    }
105
106    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64, name: &str) -> Expr {
107        operation.assert_is_type(ExprType::Binary);
108        Expr {
109            operand1: Some(Box::new(operand1)),
110            operand2: Some(Box::new(operand2)),
111            operation,
112            result,
113            is_learnable: false,
114            grad: 0.0,
115            name: name.to_string(),
116        }
117    }
118
119    /// Applies the hyperbolic tangent function to the expression and returns it as a new expression.
120    /// 
121    /// Example:
122    /// ```rust
123    /// use alpha_micrograd_rust::value::Expr;
124    /// 
125    /// let expr = Expr::new_leaf(1.0, "x");
126    /// let expr2 = expr.tanh("tanh");
127    /// 
128    /// println!("Result: {}", expr2.result); // 0.7615941559557649
129    /// ```
130    pub fn tanh(self, name: &str) -> Expr {
131        let result = self.result.tanh();
132        Expr::new_unary(self, Operation::Tanh, result, name)
133    }
134
135    /// Applies the rectified linear unit function to the expression and returns it as a new expression.
136    /// 
137    /// Example:
138    /// ```rust
139    /// use alpha_micrograd_rust::value::Expr;
140    /// 
141    /// let expr = Expr::new_leaf(-1.0, "x");
142    /// let expr2 = expr.relu("relu");
143    /// 
144    /// println!("Result: {}", expr2.result); // 0.0
145    /// ```
146    pub fn relu(self, name: &str) -> Expr {
147        let result = self.result.max(0.0);
148        Expr::new_unary(self, Operation::ReLU, result, name)
149    }
150
151    /// Applies the exponential function (e^x) to the expression and returns it as a new expression.
152    /// 
153    /// Example:
154    /// ```rust
155    /// use alpha_micrograd_rust::value::Expr;
156    /// 
157    /// let expr = Expr::new_leaf(1.0, "x");
158    /// let expr2 = expr.exp("exp");
159    /// 
160    /// println!("Result: {}", expr2.result); // 2.718281828459045
161    /// ```
162    pub fn exp(self, name: &str) -> Expr {
163        let result = self.result.exp();
164        Expr::new_unary(self, Operation::Exp, result, name)
165    }
166
167    /// Raises the expression to the power of the given exponent (expression) and returns it as a new expression.
168    /// 
169    /// Example:
170    /// ```rust
171    /// use alpha_micrograd_rust::value::Expr;
172    /// 
173    /// let expr = Expr::new_leaf(2.0, "x");
174    /// let exponent = Expr::new_leaf(3.0, "y");
175    /// let result = expr.pow(exponent, "x^y");
176    /// 
177    /// println!("Result: {}", result.result); // 8.0
178    /// ```
179    pub fn pow(self, exponent: Expr, name: &str) -> Expr {
180        let result = self.result.powf(exponent.result);
181        Expr::new_binary(self, exponent, Operation::Pow, result, name)
182    }
183
184    /// Applies the natural logarithm function to the expression and returns it as a new expression.
185    /// 
186    /// Example:
187    /// ```rust
188    /// use alpha_micrograd_rust::value::Expr;
189    /// 
190    /// let expr = Expr::new_leaf(2.0, "x");
191    /// let expr2 = expr.log("log");
192    /// 
193    /// println!("Result: {}", expr2.result); // 0.6931471805599453
194    /// ```
195    pub fn log(self, name: &str) -> Expr {
196        let result = self.result.ln();
197        Expr::new_unary(self, Operation::Log, result, name)
198    }
199
200    /// Negates the expression and returns it as a new expression.
201    /// 
202    /// Example:
203    /// ```rust
204    /// use alpha_micrograd_rust::value::Expr;
205    /// 
206    /// let expr = Expr::new_leaf(1.0, "x");
207    /// let expr2 = expr.neg("neg");
208    /// 
209    /// println!("Result: {}", expr2.result); // -1.0
210    /// ```
211    pub fn neg(self, name: &str) -> Expr {
212        let result = -self.result;
213        Expr::new_unary(self, Operation::Neg, result, name)
214    }
215
216    /// Recalculates the value of the expression recursively, from new values of the operands.
217    /// 
218    /// Usually will be used after a call to [`Expr::learn`], where the gradients have been calculated and
219    /// the internal values of the expression tree have been updated.
220    /// 
221    /// Example:
222    /// ```rust
223    /// use alpha_micrograd_rust::value::Expr;
224    /// 
225    /// let expr = Expr::new_leaf(1.0, "x");
226    /// let mut expr2 = expr.tanh("tanh(x)");
227    /// expr2.learn(1e-09);
228    /// expr2.recalculate();
229    /// ```
230    pub fn recalculate(&mut self) {
231        match self.expr_type() {
232            ExprType::Leaf => {}
233            ExprType::Unary => {
234                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
235                operand1.recalculate();
236
237                self.result = match self.operation {
238                    Operation::Tanh => operand1.result.tanh(),
239                    Operation::Exp => operand1.result.exp(),
240                    Operation::ReLU => operand1.result.max(0.0),
241                    _ => panic!("Invalid unary operation {:?}", self.operation),
242                };
243            }
244            ExprType::Binary => {
245                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
246                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
247
248                operand1.recalculate();
249                operand2.recalculate();
250
251                self.result = match self.operation {
252                    Operation::Add => operand1.result + operand2.result,
253                    Operation::Sub => operand1.result - operand2.result,
254                    Operation::Mul => operand1.result * operand2.result,
255                    Operation::Div => operand1.result / operand2.result,
256                    Operation::Pow => operand1.result.powf(operand2.result),
257                    _ => panic!("Invalid binary operation: {:?}", self.operation),
258                };
259            }
260        }
261    }
262
263    /// Applies backpropagation to the expression, updating the values of the
264    /// gradients and the expression itself.
265    /// 
266    /// This method will change the gradients based on the gradient of the last
267    /// expression in the calculation graph.
268    /// 
269    /// Example:
270    /// 
271    /// ```rust
272    /// use alpha_micrograd_rust::value::Expr;
273    /// 
274    /// let expr = Expr::new_leaf(1.0, "x");
275    /// let mut expr2 = expr.tanh("tanh(x)");
276    /// expr2.learn(1e-09);
277    /// ```
278    /// 
279    /// After adjusting the gradients, the method will update the values of the
280    /// individual expression tree nodes to minimize the loss function.
281    /// 
282    /// In order to get a new calculation of the expression tree, you'll need to call
283    /// [`Expr::recalculate`] after calling [`Expr::learn`].
284    pub fn learn(&mut self, learning_rate: f64) {
285        self.grad = 1.0;
286        self.learn_internal(learning_rate);
287    }
288
289    fn learn_internal(&mut self, learning_rate: f64) {
290        match self.expr_type() {
291            ExprType::Leaf => {
292                // leaves have their gradient set externally by other nodes in the tree
293                // leaves can be learnable, in which case we update the value
294                if self.is_learnable {
295                    self.result -= learning_rate * self.grad;
296                 }
297            }
298            ExprType::Unary => {
299                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
300
301                match self.operation {
302                    Operation::Tanh => {
303                        let tanh_grad = 1.0 - (self.result * self.result);
304                        operand1.grad = self.grad * tanh_grad;
305                    }
306                    Operation::Exp => {
307                        operand1.grad = self.grad * self.result;
308                    }
309                    Operation::ReLU => {
310                        operand1.grad = self.grad * if self.result > 0.0 { 1.0 } else { 0.0 };
311                    }
312                    Operation::Log => {
313                        operand1.grad = self.grad / operand1.result;
314                    }
315                    Operation::Neg => {
316                        operand1.grad = -self.grad;
317                    }
318                    _ => panic!("Invalid unary operation {:?}", self.operation),
319                }
320
321                operand1.learn_internal(learning_rate);
322            }
323            ExprType::Binary => {
324                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
325                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
326
327                match self.operation {
328                    Operation::Add => {
329                        operand1.grad = self.grad;
330                        operand2.grad = self.grad;
331                    }
332                    Operation::Sub => {
333                        operand1.grad = self.grad;
334                        operand2.grad = -self.grad;
335                    }
336                    Operation::Mul => {
337                        let operand2_result = operand2.result;
338                        let operand1_result = operand1.result;
339
340                        operand1.grad = self.grad * operand2_result;
341                        operand2.grad = self.grad * operand1_result;
342                    }
343                    Operation::Div => {
344                        let operand2_result = operand2.result;
345                        let operand1_result = operand1.result;
346
347                        operand1.grad = self.grad / operand2_result;
348                        operand2.grad = -self.grad * operand1_result / (operand2_result * operand2_result);
349                    }
350                    Operation::Pow => {
351                        let exponent = operand2.result;
352                        let base = operand1.result;
353
354                        operand1.grad = self.grad * exponent * base.powf(exponent - 1.0);
355                        operand2.grad = self.grad * base.powf(exponent) * base.ln();
356                    }
357                    _ => panic!("Invalid binary operation: {:?}", self.operation),
358                }
359
360                operand1.learn_internal(learning_rate);
361                operand2.learn_internal(learning_rate);
362            }
363        }
364    }
365
366    /// Finds a node in the expression tree by its name.
367    /// 
368    /// This method will search the expression tree for a node with the given name.
369    /// If the node is not found, it will return [None].
370    /// 
371    /// Example:
372    /// ```rust
373    /// use alpha_micrograd_rust::value::Expr;
374    /// 
375    /// let expr = Expr::new_leaf(1.0, "x");
376    /// let expr2 = expr.tanh("tanh(x)");
377    /// let original = expr2.find("x");
378    /// 
379    /// assert_eq!(original.expect("Could not find x").result, 1.0);
380    /// ```
381    pub fn find(&self, name: &str) -> Option<&Expr> {
382        if self.name == name {
383            return Some(self);
384        }
385
386        match self.expr_type() {
387            ExprType::Leaf => None,
388            ExprType::Unary => {
389                let operand1 = self.operand1.as_ref().expect("Unary expression did not have an operand");
390                operand1.find(name)
391            }
392            ExprType::Binary => {
393                let operand1 = self.operand1.as_ref().expect("Binary expression did not have an operand");
394                let operand2 = self.operand2.as_ref().expect("Binary expression did not have a second operand");
395
396                let result = operand1.find(name);
397                if result.is_some() {
398                    return result;
399                }
400
401                operand2.find(name)
402            }
403        }
404    }
405}
406
407/// Implements the [`Add`] trait for the [`Expr`] struct.
408/// 
409/// This implementation allows the addition of two [`Expr`] objects.
410/// 
411/// Example:
412/// ```rust
413/// use alpha_micrograd_rust::value::Expr;
414/// 
415/// let expr = Expr::new_leaf(1.0, "x");
416/// let expr2 = Expr::new_leaf(2.0, "y");
417/// 
418/// let result = expr + expr2;
419/// 
420/// println!("Result: {}", result.result); // 3.0
421/// ```
422impl Add for Expr {
423    type Output = Expr;
424
425    fn add(self, other: Expr) -> Expr {
426        let result = self.result + other.result;
427        let name = &format!("({} + {})", self.name, other.name);
428        Expr::new_binary(self, other, Operation::Add, result, name)
429    }
430}
431
432/// Implements the [`Add`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
433/// 
434/// This implementation allows the addition of an [`Expr`] object and a [`f64`] value.
435/// 
436/// Example:
437/// ```rust
438/// use alpha_micrograd_rust::value::Expr;
439/// 
440/// let expr = Expr::new_leaf(1.0, "x");
441/// let result = expr + 2.0;
442/// 
443/// println!("Result: {}", result.result); // 3.0
444/// ```
445impl Add<f64> for Expr {
446    type Output = Expr;
447
448    fn add(self, other: f64) -> Expr {
449        let operand2 = Expr::new_leaf(other, &other.to_string());
450        self + operand2
451    }
452}
453
454/// Implements the [`Add`] trait for the [`f64`] type, when the right operand is an [`Expr`].
455/// 
456/// This implementation allows the addition of a [`f64`] value and an [`Expr`] object.
457/// 
458/// Example:
459/// ```rust
460/// use alpha_micrograd_rust::value::Expr;
461/// 
462/// let expr = Expr::new_leaf(1.0, "x");
463/// let result = 2.0 + expr;
464/// 
465/// println!("Result: {}", result.result); // 3.0
466/// ```
467impl Add<Expr> for f64 {
468    type Output = Expr;
469
470    fn add(self, other: Expr) -> Expr {
471        let operand1 = Expr::new_leaf(self, &self.to_string());
472        operand1 + other
473    }
474}
475
476/// Implements the [`Mul`] trait for the [`Expr`] struct.
477/// 
478/// This implementation allows the multiplication of two [`Expr`] objects.
479/// 
480/// Example:
481/// 
482/// ```rust
483/// use alpha_micrograd_rust::value::Expr;
484/// 
485/// let expr = Expr::new_leaf(1.0, "x");
486/// let expr2 = Expr::new_leaf(2.0, "y");
487/// 
488/// let result = expr * expr2;
489/// 
490/// println!("Result: {}", result.result); // 2.0
491/// ```
492impl Mul for Expr {
493    type Output = Expr;
494
495    fn mul(self, other: Expr) -> Expr {
496        let result = self.result * other.result;
497        let name = &format!("({} * {})", self.name, other.name);
498        Expr::new_binary(self, other, Operation::Mul, result, name)
499    }
500}
501
502/// Implements the [`Mul`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
503/// 
504/// This implementation allows the multiplication of an [`Expr`] object and a [`f64`] value.
505/// 
506/// Example:
507/// 
508/// ```rust
509/// use alpha_micrograd_rust::value::Expr;
510/// 
511/// let expr = Expr::new_leaf(1.0, "x");
512/// let result = expr * 2.0;
513/// 
514/// println!("Result: {}", result.result); // 2.0
515/// ```
516impl Mul<f64> for Expr {
517    type Output = Expr;
518
519    fn mul(self, other: f64) -> Expr {
520        let operand2 = Expr::new_leaf(other, &other.to_string());
521        self * operand2
522    }
523}
524
525/// Implements the [`Mul`] trait for the [`f64`] type, when the right operand is an [`Expr`].
526/// 
527/// This implementation allows the multiplication of a [`f64`] value and an [`Expr`] object.
528/// 
529/// Example:
530/// 
531/// ```rust
532/// use alpha_micrograd_rust::value::Expr;
533/// 
534/// let expr = Expr::new_leaf(1.0, "x");
535/// let result = 2.0 * expr;
536/// 
537/// println!("Result: {}", result.result); // 2.0
538/// ```
539impl Mul<Expr> for f64 {
540    type Output = Expr;
541
542    fn mul(self, other: Expr) -> Expr {
543        let operand1 = Expr::new_leaf(self, &self.to_string());
544        operand1 * other
545    }
546}
547
548/// Implements the [`Sub`] trait for the [`Expr`] struct.
549/// 
550/// This implementation allows the subtraction of two [`Expr`] objects.
551/// 
552/// Example:
553/// 
554/// ```rust
555/// use alpha_micrograd_rust::value::Expr;
556/// 
557/// let expr = Expr::new_leaf(1.0, "x");
558/// let expr2 = Expr::new_leaf(2.0, "y");
559/// 
560/// let result = expr - expr2;
561/// 
562/// println!("Result: {}", result.result); // -1.0
563/// ```
564impl Sub for Expr {
565    type Output = Expr;
566
567    fn sub(self, other: Expr) -> Expr {
568        let result = self.result - other.result;
569        let name = &format!("({} - {})", self.name, other.name);
570        Expr::new_binary(self, other, Operation::Sub, result, name)
571    }
572}
573
574/// Implements the [`Sub`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
575/// 
576/// This implementation allows the subtraction of an [`Expr`] object and a [`f64`] value.
577/// 
578/// Example:
579/// 
580/// ```rust
581/// use alpha_micrograd_rust::value::Expr;
582/// 
583/// let expr = Expr::new_leaf(1.0, "x");
584/// let result = expr - 2.0;
585/// 
586/// println!("Result: {}", result.result); // -1.0
587/// ```
588impl Sub<f64> for Expr {
589    type Output = Expr;
590
591    fn sub(self, other: f64) -> Expr {
592        let operand2 = Expr::new_leaf(other, &other.to_string());
593        self - operand2
594    }
595}
596
597/// Implements the [`Sub`] trait for the [`f64`] type, when the right operand is an [`Expr`].
598/// 
599/// This implementation allows the subtraction of a [`f64`] value and an [`Expr`] object.
600/// 
601/// Example:
602/// 
603/// ```rust
604/// use alpha_micrograd_rust::value::Expr;
605/// 
606/// let expr = Expr::new_leaf(1.0, "x");
607/// let result = 2.0 - expr;
608/// 
609/// println!("Result: {}", result.result); // 1.0
610/// ```
611impl Sub<Expr> for f64 {
612    type Output = Expr;
613
614    fn sub(self, other: Expr) -> Expr {
615        let operand1 = Expr::new_leaf(self, &self.to_string());
616        operand1 - other
617    }
618}
619
620/// Implements the [`Div`] trait for the [`Expr`] struct.
621/// 
622/// This implementation allows the division of two [`Expr`] objects.
623/// 
624/// Example:
625/// 
626/// ```rust
627/// use alpha_micrograd_rust::value::Expr;
628/// 
629/// let expr = Expr::new_leaf(1.0, "x");
630/// let expr2 = Expr::new_leaf(2.0, "y");
631/// 
632/// let result = expr / expr2;
633/// 
634/// println!("Result: {}", result.result); // 0.5
635/// ```
636impl Div for Expr {
637    type Output = Expr;
638
639    fn div(self, other: Expr) -> Expr {
640        let result = self.result / other.result;
641        let name = &format!("({} / {})", self.name, other.name);
642        Expr::new_binary(self, other, Operation::Div, result, name)
643    }
644}
645
646/// Implements the [`Div`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
647/// 
648/// This implementation allows the division of an [`Expr`] object and a [`f64`] value.
649/// 
650/// Example:
651/// 
652/// ```rust
653/// use alpha_micrograd_rust::value::Expr;
654/// 
655/// let expr = Expr::new_leaf(1.0, "x");
656/// let result = expr / 2.0;
657/// 
658/// println!("Result: {}", result.result); // 0.5
659/// ```
660impl Div<f64> for Expr {
661    type Output = Expr;
662
663    fn div(self, other: f64) -> Expr {
664        let operand2 = Expr::new_leaf(other, &other.to_string());
665        self / operand2
666    }
667}
668
669/// Implements the [`Sum`] trait for the [`Expr`] struct.
670/// 
671/// Note that this implementation will generate temporary [`Expr`] objects,
672/// which may not be the most efficient way to sum a collection of [`Expr`] objects.
673/// However, it is provided as a convenience method for users that want to use sum
674/// over an [`Iterator<Expr>`].
675/// 
676/// Example:
677/// 
678/// ```rust
679/// use alpha_micrograd_rust::value::Expr;
680/// 
681/// let expr = Expr::new_leaf(1.0, "x");
682/// let expr2 = Expr::new_leaf(2.0, "y");
683/// let expr3 = Expr::new_leaf(3.0, "z");
684/// 
685/// let sum = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
686/// 
687/// println!("Result: {}", sum.result); // 6.0
688/// ```
689impl Sum for Expr {
690    fn sum<I>(iter: I) -> Self
691    where
692        I: Iterator<Item = Self>,
693    {
694        iter.reduce(|acc, x| acc + x)
695            .unwrap_or(Expr::new_leaf(0.0, "0.0"))
696    }
697}
698
699#[cfg(test)]
700mod tests {
701    use super::*;
702
703    fn assert_float_eq(f1: f64, f2: f64) {
704        let places = 7;
705        let tolerance = 10.0_f64.powi(-places);
706        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
707    }
708
709    #[test]
710    fn test() {
711        let expr = Expr::new_leaf(1.0, "x");
712        assert_eq!(expr.result, 1.0);
713    }
714
715    #[test]
716    fn test_unary() {
717        let expr = Expr::new_leaf(1.0, "x");
718        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1, "tanh(x)");
719
720        assert_eq!(expr2.result, 1.1);
721        assert_eq!(expr2.operand1.unwrap().result, 1.0);
722    }
723
724    #[test]
725    #[should_panic]
726    fn test_unary_expression_type_check() {
727        let expr = Expr::new_leaf(1.0, "x");
728        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1, "tanh(x)");
729    }
730
731    #[test]
732    fn test_binary() {
733        let expr = Expr::new_leaf(1.0, "x");
734        let expr2 = Expr::new_leaf(2.0, "y");
735        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1, "x + y");
736
737        assert_eq!(expr3.result, 1.1);
738        assert_eq!(expr3.operand1.unwrap().result, 1.0);
739        assert_eq!(expr3.operand2.unwrap().result, 2.0);
740    }
741
742    #[test]
743    #[should_panic]
744    fn test_binary_expression_type_check() {
745        let expr = Expr::new_leaf(1.0, "x");
746        let expr2 = Expr::new_leaf(2.0, "y");
747        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0, "x + y");
748    }
749
750    #[test]
751    fn test_mixed_tree() {
752        let expr = Expr::new_leaf(1.0, "x");
753        let expr2 = Expr::new_leaf(2.0, "y");
754        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1, "x - y");
755        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2, "tanh(x - y)");
756
757        assert_eq!(expr4.result, 1.2);
758        let expr3 = expr4.operand1.unwrap();
759        assert_eq!(expr3.result, 1.1);
760        assert_eq!(expr3.operand1.unwrap().result, 1.0);
761        assert_eq!(expr3.operand2.unwrap().result, 2.0);
762    }
763
764    #[test]
765    fn test_tanh() {
766        let expr = Expr::new_leaf(1.0, "x");
767        let expr2 = expr.tanh("tanh(x)");
768
769        assert_eq!(expr2.result, 0.7615941559557649);
770        assert!(expr2.operand1.is_some());
771        assert_eq!(expr2.operand1.unwrap().result, 1.0);
772        assert_eq!(expr2.operation, Operation::Tanh);
773        assert!(expr2.operand2.is_none());
774
775        // Some other known values
776        fn get_tanh(x: f64) -> f64 {
777            Expr::new_leaf(x, "x").tanh("tanh(x)").result
778        }
779
780        assert_float_eq(get_tanh(10.74), 0.9999999);
781        assert_float_eq(get_tanh(-10.74), -0.9999999);
782        assert_float_eq(get_tanh(0.0), 0.0);
783    }
784
785    #[test]
786    fn test_exp() {
787        let expr = Expr::new_leaf(1.0, "x");
788        let expr2 = expr.exp("exp(x)");
789
790        assert_eq!(expr2.result, 2.718281828459045);
791        assert!(expr2.operand1.is_some());
792        assert_eq!(expr2.operand1.unwrap().result, 1.0);
793        assert_eq!(expr2.operation, Operation::Exp);
794        assert!(expr2.operand2.is_none());
795    }
796
797    #[test]
798    fn test_relu() {
799        // negative case
800        let expr = Expr::new_leaf(-1.0, "x");
801        let expr2 = expr.relu("relu(x)");
802
803        assert_eq!(expr2.result, 0.0);
804        assert!(expr2.operand1.is_some());
805        assert_eq!(expr2.operand1.unwrap().result, -1.0);
806        assert_eq!(expr2.operation, Operation::ReLU);
807        assert!(expr2.operand2.is_none());
808
809        // positive case
810        let expr = Expr::new_leaf(1.0, "x");
811        let expr2 = expr.relu("relu(x)");
812
813        assert_eq!(expr2.result, 1.0);
814        assert!(expr2.operand1.is_some());
815        assert_eq!(expr2.operand1.unwrap().result, 1.0);
816        assert_eq!(expr2.operation, Operation::ReLU);
817        assert!(expr2.operand2.is_none());
818    }
819
820    #[test]
821    fn test_pow() {
822        let expr = Expr::new_leaf(2.0, "x");
823        let expr2 = Expr::new_leaf(3.0, "y");
824        let result = expr.pow(expr2, "x^y");
825
826        assert_eq!(result.result, 8.0);
827        assert!(result.operand1.is_some());
828        assert_eq!(result.operand1.unwrap().result, 2.0);
829        assert_eq!(result.operation, Operation::Pow);
830        
831        assert!(result.operand2.is_some());
832        assert_eq!(result.operand2.unwrap().result, 3.0);
833    }
834
835    #[test]
836    fn test_add() {
837        let expr = Expr::new_leaf(1.0, "x");
838        let expr2 = Expr::new_leaf(2.0, "y");
839        let expr3 = expr + expr2;
840
841        assert_eq!(expr3.result, 3.0);
842        assert!(expr3.operand1.is_some());
843        assert_eq!(expr3.operand1.unwrap().result, 1.0);
844        assert!(expr3.operand2.is_some());
845        assert_eq!(expr3.operand2.unwrap().result, 2.0);
846        assert_eq!(expr3.operation, Operation::Add);
847        assert_eq!(expr3.name, "(x + y)");
848    }
849
850    #[test]
851    fn test_add_f64() {
852        let expr = Expr::new_leaf(1.0, "x");
853        let expr2 = expr + 2.0;
854
855        assert_eq!(expr2.result, 3.0);
856        assert!(expr2.operand1.is_some());
857        assert_eq!(expr2.operand1.unwrap().result, 1.0);
858        assert!(expr2.operand2.is_some());
859        assert_eq!(expr2.operand2.unwrap().result, 2.0);
860        assert_eq!(expr2.operation, Operation::Add);
861        assert_eq!(expr2.name, "(x + 2)");
862    }
863
864    #[test]
865    fn test_add_f64_expr() {
866        let expr = Expr::new_leaf(1.0, "x");
867        let expr2 = 2.0 + expr;
868
869        assert_eq!(expr2.result, 3.0);
870        assert!(expr2.operand1.is_some());
871        assert_eq!(expr2.operand1.unwrap().result, 2.0);
872        assert!(expr2.operand2.is_some());
873        assert_eq!(expr2.operand2.unwrap().result, 1.0);
874        assert_eq!(expr2.operation, Operation::Add);
875        assert_eq!(expr2.name, "(2 + x)");
876    }
877
878    #[test]
879    fn test_mul() {
880        let expr = Expr::new_leaf(2.0, "x");
881        let expr2 = Expr::new_leaf(3.0, "y");
882        let expr3 = expr * expr2;
883
884        assert_eq!(expr3.result, 6.0);
885        assert!(expr3.operand1.is_some());
886        assert_eq!(expr3.operand1.unwrap().result, 2.0);
887        assert!(expr3.operand2.is_some());
888        assert_eq!(expr3.operand2.unwrap().result, 3.0);
889        assert_eq!(expr3.operation, Operation::Mul);
890        assert_eq!(expr3.name, "(x * y)");
891    }
892
893    #[test]
894    fn test_mul_f64() {
895        let expr = Expr::new_leaf(2.0, "x");
896        let expr2 = expr * 3.0;
897
898        assert_eq!(expr2.result, 6.0);
899        assert!(expr2.operand1.is_some());
900        assert_eq!(expr2.operand1.unwrap().result, 2.0);
901        assert!(expr2.operand2.is_some());
902        assert_eq!(expr2.operand2.unwrap().result, 3.0);
903        assert_eq!(expr2.operation, Operation::Mul);
904        assert_eq!(expr2.name, "(x * 3)");
905    }
906
907    #[test]
908    fn test_mul_f64_expr() {
909        let expr = Expr::new_leaf(2.0, "x");
910        let expr2 = 3.0 * expr;
911
912        assert_eq!(expr2.result, 6.0);
913        assert!(expr2.operand1.is_some());
914        assert_eq!(expr2.operand1.unwrap().result, 3.0);
915        assert!(expr2.operand2.is_some());
916        assert_eq!(expr2.operand2.unwrap().result, 2.0);
917        assert_eq!(expr2.operation, Operation::Mul);
918        assert_eq!(expr2.name, "(3 * x)");
919    }
920
921    #[test]
922    fn test_sub() {
923        let expr = Expr::new_leaf(2.0, "x");
924        let expr2 = Expr::new_leaf(3.0, "y");
925        let expr3 = expr - expr2;
926
927        assert_eq!(expr3.result, -1.0);
928        assert!(expr3.operand1.is_some());
929        assert_eq!(expr3.operand1.unwrap().result, 2.0);
930        assert!(expr3.operand2.is_some());
931        assert_eq!(expr3.operand2.unwrap().result, 3.0);
932        assert_eq!(expr3.operation, Operation::Sub);
933        assert_eq!(expr3.name, "(x - y)");
934    }
935
936    #[test]
937    fn test_sub_f64() {
938        let expr = Expr::new_leaf(2.0, "x");
939        let expr2 = expr - 3.0;
940
941        assert_eq!(expr2.result, -1.0);
942        assert!(expr2.operand1.is_some());
943        assert_eq!(expr2.operand1.unwrap().result, 2.0);
944        assert!(expr2.operand2.is_some());
945        assert_eq!(expr2.operand2.unwrap().result, 3.0);
946        assert_eq!(expr2.operation, Operation::Sub);
947        assert_eq!(expr2.name, "(x - 3)");
948    }
949
950    #[test]
951    fn test_sub_f64_expr() {
952        let expr = Expr::new_leaf(2.0, "x");
953        let expr2 = 3.0 - expr;
954
955        assert_eq!(expr2.result, 1.0);
956        assert!(expr2.operand1.is_some());
957        assert_eq!(expr2.operand1.unwrap().result, 3.0);
958        assert!(expr2.operand2.is_some());
959        assert_eq!(expr2.operand2.unwrap().result, 2.0);
960        assert_eq!(expr2.operation, Operation::Sub);
961        assert_eq!(expr2.name, "(3 - x)");
962    }
963
964    #[test]
965    fn test_div() {
966        let expr = Expr::new_leaf(6.0, "x");
967        let expr2 = Expr::new_leaf(3.0, "y");
968        let expr3 = expr / expr2;
969
970        assert_eq!(expr3.result, 2.0);
971        assert!(expr3.operand1.is_some());
972        assert_eq!(expr3.operand1.unwrap().result, 6.0);
973        assert!(expr3.operand2.is_some());
974        assert_eq!(expr3.operand2.unwrap().result, 3.0);
975        assert_eq!(expr3.operation, Operation::Div);
976        assert_eq!(expr3.name, "(x / y)");
977    }
978
979    #[test]
980    fn test_div_f64() {
981        let expr = Expr::new_leaf(6.0, "x");
982        let expr2 = expr / 3.0;
983
984        assert_eq!(expr2.result, 2.0);
985        assert!(expr2.operand1.is_some());
986        assert_eq!(expr2.operand1.unwrap().result, 6.0);
987        assert!(expr2.operand2.is_some());
988        assert_eq!(expr2.operand2.unwrap().result, 3.0);
989        assert_eq!(expr2.operation, Operation::Div);
990        assert_eq!(expr2.name, "(x / 3)");
991    }
992
993    #[test]
994    fn test_log() {
995        let expr = Expr::new_leaf(2.0, "x");
996        let expr2 = expr.log("log(x)");
997
998        assert_eq!(expr2.result, 0.6931471805599453);
999        assert!(expr2.operand1.is_some());
1000        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1001        assert_eq!(expr2.operation, Operation::Log);
1002        assert!(expr2.operand2.is_none());
1003    }
1004
1005    #[test]
1006    fn test_neg() {
1007        let expr = Expr::new_leaf(2.0, "x");
1008        let expr2 = expr.neg("neg(x)");
1009
1010        assert_eq!(expr2.result, -2.0);
1011        assert!(expr2.operand1.is_some());
1012        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1013        assert_eq!(expr2.operation, Operation::Neg);
1014        assert!(expr2.operand2.is_none());
1015    }
1016
1017    #[test]
1018    fn test_backpropagation_add() {
1019        let operand1 = Expr::new_leaf(1.0, "x");
1020        let operand2 = Expr::new_leaf(2.0, "y");
1021        let mut expr3 = operand1 + operand2;
1022
1023        expr3.learn(1e-09);
1024
1025        let operand1 = expr3.operand1.unwrap();
1026        let operand2 = expr3.operand2.unwrap();
1027        assert_eq!(operand1.grad, 1.0);
1028        assert_eq!(operand2.grad, 1.0);
1029    }
1030
1031    #[test]
1032    fn test_backpropagation_sub() {
1033        let operand1 = Expr::new_leaf(1.0, "x");
1034        let operand2 = Expr::new_leaf(2.0, "y");
1035        let mut expr3 = operand1 - operand2;
1036
1037        expr3.learn(1e-09);
1038
1039        let operand1 = expr3.operand1.unwrap();
1040        let operand2 = expr3.operand2.unwrap();
1041        assert_eq!(operand1.grad, 1.0);
1042        assert_eq!(operand2.grad, -1.0);
1043    }
1044
1045    #[test]
1046    fn test_backpropagation_mul() {
1047        let operand1 = Expr::new_leaf(3.0, "x");
1048        let operand2 = Expr::new_leaf(4.0, "y");
1049        let mut expr3 = operand1 * operand2;
1050
1051        expr3.learn(1e-09);
1052
1053        let operand1 = expr3.operand1.unwrap();
1054        let operand2 = expr3.operand2.unwrap();
1055        assert_eq!(operand1.grad, 4.0);
1056        assert_eq!(operand2.grad, 3.0);
1057    }
1058
1059    #[test]
1060    fn test_backpropagation_div() {
1061        let operand1 = Expr::new_leaf(3.0, "x");
1062        let operand2 = Expr::new_leaf(4.0, "y");
1063        let mut expr3 = operand1 / operand2;
1064
1065        expr3.learn(1e-09);
1066
1067        let operand1 = expr3.operand1.unwrap();
1068        let operand2 = expr3.operand2.unwrap();
1069        assert_eq!(operand1.grad, 0.25);
1070        assert_eq!(operand2.grad, -0.1875);
1071    }
1072
1073    #[test]
1074    fn test_backpropagation_tanh() {
1075        let operand1 = Expr::new_leaf(0.0, "x");
1076        let mut expr2 = operand1.tanh("tanh(x)");
1077
1078        expr2.learn(1e-09);
1079
1080        let operand1 = expr2.operand1.unwrap();
1081        assert_eq!(operand1.grad, 1.0);
1082    }
1083
1084    #[test]
1085    fn test_backpropagation_relu() {
1086        let operand1 = Expr::new_leaf(-1.0, "x");
1087        let mut expr2 = operand1.relu("relu(x)");
1088
1089        expr2.learn(1e-09);
1090
1091        let operand1 = expr2.operand1.unwrap();
1092        assert_eq!(operand1.grad, 0.0);
1093    }
1094
1095    #[test]
1096    fn test_backpropagation_exp() {
1097        let operand1 = Expr::new_leaf(0.0, "x");
1098        let mut expr2 = operand1.exp("exp(x)");
1099
1100        expr2.learn(1e-09);
1101
1102        let operand1 = expr2.operand1.unwrap();
1103        assert_eq!(operand1.grad, 1.0);
1104    }
1105
1106    #[test]
1107    fn test_backpropagation_pow() {
1108        let operand1 = Expr::new_leaf(2.0, "x");
1109        let operand2 = Expr::new_leaf(3.0, "y");
1110        let mut expr3 = operand1.pow(operand2, "x^y");
1111
1112        expr3.learn(1e-09);
1113
1114        let operand1 = expr3.operand1.unwrap();
1115        let operand2 = expr3.operand2.unwrap();
1116        assert_eq!(operand1.grad, 12.0);
1117        assert_eq!(operand2.grad, 5.545177444479562);
1118    }
1119
1120    #[test]
1121    fn test_backpropagation_mixed_tree() {
1122        let operand1 = Expr::new_leaf(1.0, "x");
1123        let operand2 = Expr::new_leaf(2.0, "y");
1124        let expr3 = operand1 + operand2;
1125        let mut expr4 = expr3.tanh("tanh(x + y)");
1126
1127        expr4.learn(1e-09);
1128
1129        let expr3 = expr4.operand1.unwrap();
1130        let operand1 = expr3.operand1.unwrap();
1131        let operand2 = expr3.operand2.unwrap();
1132
1133        assert_eq!(expr3.grad, 0.009866037165440211);
1134        assert_eq!(operand1.grad, 0.009866037165440211);
1135        assert_eq!(operand2.grad, 0.009866037165440211);
1136    }
1137
1138    #[test]
1139    fn test_backpropagation_karpathys_example() {
1140        let x1 = Expr::new_leaf(2.0, "x1");
1141        let x2 = Expr::new_leaf(0.0, "x2");
1142        let w1 = Expr::new_leaf(-3.0, "w1");
1143        let w2 = Expr::new_leaf(1.0, "w2");
1144        let b = Expr::new_leaf(6.8813735870195432, "b");
1145
1146        let x1w1 = x1 * w1;
1147        let x2w2 = x2 * w2;
1148        let x1w1_x2w2 = x1w1 + x2w2;
1149        let n = x1w1_x2w2 + b;
1150        let mut o = n.tanh("tanh(n)");
1151
1152        o.learn(1e-09);
1153
1154        assert_eq!(o.operation, Operation::Tanh);
1155        assert_eq!(o.grad, 1.0);
1156
1157        let n = o.operand1.unwrap();
1158        assert_eq!(n.operation, Operation::Add);
1159        assert_float_eq(n.grad, 0.5);
1160
1161        let x1w1_x2w2 = n.operand1.unwrap();
1162        assert_eq!(x1w1_x2w2.operation, Operation::Add);
1163        assert_float_eq(x1w1_x2w2.grad, 0.5);
1164
1165        let b = n.operand2.unwrap();
1166        assert_eq!(b.operation, Operation::None);
1167        assert_float_eq(b.grad, 0.5);
1168
1169        let x1w1 = x1w1_x2w2.operand1.unwrap();
1170        assert_eq!(x1w1.operation, Operation::Mul);
1171        assert_float_eq(x1w1.grad, 0.5);
1172
1173        let x2w2 = x1w1_x2w2.operand2.unwrap();
1174        assert_eq!(x2w2.operation, Operation::Mul);
1175        assert_float_eq(x2w2.grad, 0.5);
1176
1177        let x1 = x1w1.operand1.unwrap();
1178        assert_eq!(x1.operation, Operation::None);
1179        assert_float_eq(x1.grad, -1.5);
1180
1181        let w1 = x1w1.operand2.unwrap();
1182        assert_eq!(w1.operation, Operation::None);
1183        assert_float_eq(w1.grad, 1.0);
1184
1185        let x2 = x2w2.operand1.unwrap();
1186        assert_eq!(x2.operation, Operation::None);
1187        assert_float_eq(x2.grad, 0.5);
1188
1189        let w2 = x2w2.operand2.unwrap();
1190        assert_eq!(w2.operation, Operation::None);
1191        assert_float_eq(w2.grad, 0.0);
1192    }
1193
1194    #[test]
1195    fn test_learn_simple() {
1196        let mut expr = Expr::new_leaf(1.0, "x");
1197        expr.learn(1e-01);
1198
1199        assert_float_eq(expr.result, 0.9);
1200    }
1201
1202    #[test]
1203    fn test_learn_skips_non_learnable() {
1204        let mut expr = Expr::new_leaf(1.0, "x");
1205        expr.is_learnable = false;
1206        expr.learn(1e-01);
1207
1208        assert_float_eq(expr.result, 1.0);
1209    }
1210
1211    #[test]
1212    fn test_find_simple() {
1213        let expr = Expr::new_leaf(1.0, "x");
1214        let expr2 = expr.tanh("tanh(x)");
1215
1216        let found = expr2.find("x");
1217        assert!(found.is_some());
1218        assert_eq!(found.unwrap().name, "x");
1219    }
1220
1221    #[test]
1222    fn test_find_not_found() {
1223        let expr = Expr::new_leaf(1.0, "x");
1224        let expr2 = expr.tanh("tanh(x)");
1225
1226        let found = expr2.find("y");
1227        assert!(found.is_none());
1228    }
1229
1230    #[test]
1231    fn test_sum_iterator() {
1232        let expr = Expr::new_leaf(1.0, "x");
1233        let expr2 = Expr::new_leaf(2.0, "y");
1234        let expr3 = Expr::new_leaf(3.0, "z");
1235
1236        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
1237        assert_eq!(sum.result, 6.0);
1238    }
1239
1240    #[test]
1241    fn test_find_after_clone() {
1242        let expr = Expr::new_leaf(1.0, "x");
1243        let expr2 = expr.tanh("tanh(x)");
1244        let expr2_clone = expr2.clone();
1245
1246        let found = expr2_clone.find("x");
1247        assert!(found.is_some());
1248        assert_eq!(found.unwrap().name, "x");
1249    }
1250}