alpha_micrograd_rust/
value.rs

1//! A simple library for creating and backpropagating through expression trees.
2//! 
3//! This package includes the following elements to construct expression trees:
4//! 
5//! - [`value::Expr`]: a node in the expression tree.
6#![deny(missing_docs)]
7use std::collections::VecDeque;
8use std::ops::{Add, Div, Mul, Sub};
9use std::iter::Sum;
10
/// Kind of computation that produced an [`Expr`] node.
///
/// `None` marks a leaf; `Tanh`, `Exp`, `ReLU`, `Log` and `Neg` take a single operand;
/// every remaining variant takes two operands.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum Operation {
    /// No operation: the node is a leaf holding a plain value.
    None,
    /// Binary addition (`operand1 + operand2`).
    Add,
    /// Binary subtraction (`operand1 - operand2`).
    Sub,
    /// Binary multiplication (`operand1 * operand2`).
    Mul,
    /// Binary division (`operand1 / operand2`).
    Div,
    /// Unary hyperbolic tangent.
    Tanh,
    /// Unary exponential (`e^x`).
    Exp,
    /// Binary power (`operand1` raised to `operand2`).
    Pow,
    /// Unary rectified linear unit (`max(x, 0)`).
    ReLU,
    /// Unary natural logarithm.
    Log,
    /// Unary negation.
    Neg,
}
25
26impl Operation {
27    fn assert_is_type(&self, expr_type: ExprType) {
28        match self {
29            Operation::None => assert_eq!(expr_type, ExprType::Leaf),
30            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => assert_eq!(expr_type, ExprType::Unary),
31            _ => assert_eq!(expr_type, ExprType::Binary),
32        }
33    }
34}
35
/// Structural category of an [`Expr`] node, determined by its [`Operation`].
#[derive(Debug, PartialEq)]
pub(crate) enum ExprType {
    /// A node without operands (its operation is `Operation::None`).
    Leaf,
    /// A node produced by a single-operand operation.
    Unary,
    /// A node produced by a two-operand operation.
    Binary,
}
42
/// Expression representing a node in a calculation graph.
///
/// This struct represents a node in a calculation graph. It can be a leaf node, a unary operation or a binary operation.
///
/// A leaf node holds a value, which is the one that is used in the calculation.
///
/// A unary expression is the result of applying a unary operation to another expression. For example, the result of applying the `tanh` operation to a leaf node.
///
/// A binary expression is the result of applying a binary operation to two other expressions. For example, the result of adding two leaf nodes.
#[derive(Debug, Clone)]
pub struct Expr {
    /// First (or only) operand; `None` for leaf nodes. The node owns its operand.
    pub(crate) operand1: Option<Box<Expr>>,
    /// Second operand; only populated for binary nodes.
    pub(crate) operand2: Option<Box<Expr>>,
    /// The operation applied to the operands, if any.
    pub(crate) operation: Operation,
    /// The numeric result of the expression, as result of applying the operation to the operands.
    pub result: f64,
    /// Whether the expression is learnable or not. Only learnable [`Expr`] will have their values updated during backpropagation (learning).
    pub is_learnable: bool,
    /// Gradient assigned to this node by the most recent backpropagation pass; starts at `0.0`.
    pub(crate) grad: f64,
    /// The name of the expression, used to identify it in the calculation graph.
    pub name: Option<String>,
}
66
67impl Expr {
68    /// Creates a new leaf expression with the given value.
69    /// 
70    /// Example:
71    /// ```rust
72    /// use alpha_micrograd_rust::value::Expr;
73    /// 
74    /// let expr = Expr::new_leaf(1.0);
75    /// ```
76    pub fn new_leaf(value: f64) -> Expr {
77        Expr {
78            operand1: None,
79            operand2: None,
80            operation: Operation::None,
81            result: value,
82            is_learnable: true,
83            grad: 0.0,
84            name: None,
85        }
86    }
87
88    /// Creates a new leaf expression with the given value and name.
89    /// 
90    /// Example:
91    /// ```rust
92    /// use alpha_micrograd_rust::value::Expr;
93    /// 
94    /// let expr = Expr::new_leaf_with_name(1.0, "x");
95    /// 
96    /// assert_eq!(expr.name, Some("x".to_string()));
97    /// ```
98    pub fn new_leaf_with_name(value: f64, name: &str) -> Expr {
99        let mut expr = Expr::new_leaf(value);
100        expr.name = Some(name.to_string());
101        expr
102    }
103
104    pub(crate) fn expr_type(&self) -> ExprType {
105        match self.operation {
106            Operation::None => ExprType::Leaf,
107            Operation::Tanh | Operation::Exp | Operation::ReLU | Operation::Log | Operation::Neg => ExprType::Unary,
108            _ => ExprType::Binary,
109        }
110    }
111
112    fn new_unary(operand: Expr, operation: Operation, result: f64) -> Expr {
113        operation.assert_is_type(ExprType::Unary);
114        Expr {
115            operand1: Some(Box::new(operand)),
116            operand2: None,
117            operation,
118            result,
119            is_learnable: false,
120            grad: 0.0,
121            name: None,
122        }
123    }
124
125    fn new_binary(operand1: Expr, operand2: Expr, operation: Operation, result: f64) -> Expr {
126        operation.assert_is_type(ExprType::Binary);
127        Expr {
128            operand1: Some(Box::new(operand1)),
129            operand2: Some(Box::new(operand2)),
130            operation,
131            result,
132            is_learnable: false,
133            grad: 0.0,
134            name: None,
135        }
136    }
137
138    /// Applies the hyperbolic tangent function to the expression and returns it as a new expression.
139    /// 
140    /// Example:
141    /// ```rust
142    /// use alpha_micrograd_rust::value::Expr;
143    /// 
144    /// let expr = Expr::new_leaf(1.0);
145    /// let expr2 = expr.tanh();
146    /// 
147    /// assert_eq!(expr2.result, 0.7615941559557649);
148    /// ```
149    pub fn tanh(self) -> Expr {
150        let result = self.result.tanh();
151        Expr::new_unary(self, Operation::Tanh, result)
152    }
153
154    /// Applies the rectified linear unit function to the expression and returns it as a new expression.
155    /// 
156    /// Example:
157    /// ```rust
158    /// use alpha_micrograd_rust::value::Expr;
159    /// 
160    /// let expr = Expr::new_leaf(-1.0);
161    /// let expr2 = expr.relu();
162    /// 
163    /// assert_eq!(expr2.result, 0.0);
164    /// ```
165    pub fn relu(self) -> Expr {
166        let result = self.result.max(0.0);
167        Expr::new_unary(self, Operation::ReLU, result)
168    }
169
170    /// Applies the exponential function (e^x) to the expression and returns it as a new expression.
171    /// 
172    /// Example:
173    /// ```rust
174    /// use alpha_micrograd_rust::value::Expr;
175    /// 
176    /// let expr = Expr::new_leaf(1.0);
177    /// let expr2 = expr.exp();
178    /// 
179    /// assert_eq!(expr2.result, 2.718281828459045);
180    /// ```
181    pub fn exp(self) -> Expr {
182        let result = self.result.exp();
183        Expr::new_unary(self, Operation::Exp, result)
184    }
185
186    /// Raises the expression to the power of the given exponent (expression) and returns it as a new expression.
187    /// 
188    /// Example:
189    /// ```rust
190    /// use alpha_micrograd_rust::value::Expr;
191    /// 
192    /// let expr = Expr::new_leaf(2.0);
193    /// let exponent = Expr::new_leaf(3.0);
194    /// let result = expr.pow(exponent);
195    /// 
196    /// assert_eq!(result.result, 8.0);
197    /// ```
198    pub fn pow(self, exponent: Expr) -> Expr {
199        let result = self.result.powf(exponent.result);
200        Expr::new_binary(self, exponent, Operation::Pow, result)
201    }
202
203    /// Applies the natural logarithm function to the expression and returns it as a new expression.
204    /// 
205    /// Example:
206    /// ```rust
207    /// use alpha_micrograd_rust::value::Expr;
208    /// 
209    /// let expr = Expr::new_leaf(2.0);
210    /// let expr2 = expr.log();
211    /// 
212    /// assert_eq!(expr2.result, 0.6931471805599453);
213    /// ```
214    pub fn log(self) -> Expr {
215        let result = self.result.ln();
216        Expr::new_unary(self, Operation::Log, result)
217    }
218
219    /// Negates the expression and returns it as a new expression.
220    /// 
221    /// Example:
222    /// ```rust
223    /// use alpha_micrograd_rust::value::Expr;
224    /// 
225    /// let expr = Expr::new_leaf(1.0);
226    /// let expr2 = expr.neg();
227    /// 
228    /// assert_eq!(expr2.result, -1.0);
229    /// ```
230    pub fn neg(self) -> Expr {
231        let result = -self.result;
232        Expr::new_unary(self, Operation::Neg, result)
233    }
234
235    /// Recalculates the value of the expression recursively, from new values of the operands.
236    /// 
237    /// Usually will be used after a call to [`Expr::learn`], where the gradients have been calculated and
238    /// the internal values of the expression tree have been updated.
239    /// 
240    /// Example:
241    /// ```rust
242    /// use alpha_micrograd_rust::value::Expr;
243    /// 
244    /// let expr = Expr::new_leaf(1.0);
245    /// let mut expr2 = expr.tanh();
246    /// expr2.learn(1e-09);
247    /// expr2.recalculate();
248    /// 
249    /// assert_eq!(expr2.result, 0.7615941557793864);
250    /// ```
251    /// 
252    /// You can also vary the values of the operands and recalculate the expression:
253    /// ```rust
254    /// use alpha_micrograd_rust::value::Expr;
255    /// 
256    /// let expr = Expr::new_leaf_with_name(1.0, "x");
257    /// let mut expr2 = expr.tanh();
258    /// 
259    /// let mut original = expr2.find_mut("x").expect("Could not find x");
260    /// original.result = 2.0;
261    /// expr2.recalculate();
262    /// 
263    /// assert_eq!(expr2.result, 0.9640275800758169);
264    /// ```
265    pub fn recalculate(&mut self) {
266        // TODO: Since we can't borrow the operands mutably without inferring multible borrows from
267        // the current node, this approach will need to stay recursive for now.
268        // We can replace it with an iterative approach after we implement an allocation arena at the
269        // tree level and then we can just visit them in a regular loop.
270        match self.expr_type() {
271            ExprType::Leaf => {}
272            ExprType::Unary => {
273                let operand1 = self.operand1.as_mut().expect("Unary expression did not have an operand");
274                operand1.recalculate();
275
276                self.result = match self.operation {
277                    Operation::Tanh => operand1.result.tanh(),
278                    Operation::Exp => operand1.result.exp(),
279                    Operation::ReLU => operand1.result.max(0.0),
280                    Operation::Log => operand1.result.ln(),
281                    Operation::Neg => -operand1.result,
282                    _ => panic!("Invalid unary operation {:?}", self.operation),
283                };
284            }
285            ExprType::Binary => {
286                let operand1 = self.operand1.as_mut().expect("Binary expression did not have an operand");
287                let operand2 = self.operand2.as_mut().expect("Binary expression did not have a second operand");
288
289                operand1.recalculate();
290                operand2.recalculate();
291
292                self.result = match self.operation {
293                    Operation::Add => operand1.result + operand2.result,
294                    Operation::Sub => operand1.result - operand2.result,
295                    Operation::Mul => operand1.result * operand2.result,
296                    Operation::Div => operand1.result / operand2.result,
297                    Operation::Pow => operand1.result.powf(operand2.result),
298                    _ => panic!("Invalid binary operation: {:?}", self.operation),
299                };
300            }
301        }
302    }
303
304    /// Applies backpropagation to the expression, updating the values of the
305    /// gradients and the expression itself.
306    /// 
307    /// This method will change the gradients based on the gradient of the last
308    /// expression in the calculation graph.
309    /// 
310    /// Example:
311    /// 
312    /// ```rust
313    /// use alpha_micrograd_rust::value::Expr;
314    /// 
315    /// let expr = Expr::new_leaf(1.0);
316    /// let mut expr2 = expr.tanh();
317    /// expr2.learn(1e-09);
318    /// ```
319    /// 
320    /// After adjusting the gradients, the method will update the values of the
321    /// individual expression tree nodes to minimize the loss function.
322    /// 
323    /// In order to get a new calculation of the expression tree, you'll need to call
324    /// [`Expr::recalculate`] after calling [`Expr::learn`].
325    pub fn learn(&mut self, learning_rate: f64) {
326        self.grad = 1.0;
327
328        let mut queue = VecDeque::from([self]);
329
330        while let Some(node) = queue.pop_front() {
331            match node.expr_type() {
332                ExprType::Leaf => {
333                    node.learn_internal_leaf(learning_rate);
334                }
335                ExprType::Unary => {
336                    let operand1 = node.operand1.as_mut().expect("Unary expression did not have an operand");
337                    operand1.adjust_grad_unary(&node.operation, node.grad, node.result);
338                    queue.push_back(operand1);
339                }
340                ExprType::Binary => {
341                    let operand1 = node.operand1.as_mut().expect("Binary expression did not have an operand");
342                    let operand2 = node.operand2.as_mut().expect("Binary expression did not have a second operand");
343
344                    operand1.adjust_grad_binary_op1(&node.operation, node.grad, operand2);
345                    operand2.adjust_grad_binary_op2(&node.operation, node.grad, operand1);
346
347                    queue.push_back(operand1);
348                    queue.push_back(operand2);
349                }
350            }
351        }
352    }
353
354    fn learn_internal_leaf(&mut self, learning_rate: f64) {
355        // leaves have their gradient set externally by other nodes in the tree
356        // leaves can be learnable, in which case we update the value
357        if self.is_learnable {
358            self.result -= learning_rate * self.grad;
359        }
360    }
361
362    fn adjust_grad_unary(&mut self, child_operation: &Operation, child_grad: f64, child_result: f64) {
363        match child_operation {
364            Operation::Tanh => {
365                let tanh_grad = 1.0 - (child_result * child_result);
366                self.grad = child_grad * tanh_grad;
367            }
368            Operation::Exp => {
369                self.grad = child_grad * child_result;
370            }
371            Operation::ReLU => {
372                self.grad = child_grad * if child_result > 0.0 { 1.0 } else { 0.0 };
373            }
374            Operation::Log => {
375                self.grad = child_grad / child_result;
376            }
377            Operation::Neg => {
378                self.grad = -child_grad;
379            }
380            _ => panic!("Invalid unary operation {:?}", child_operation),
381        }
382    }
383
384    fn adjust_grad_binary_op1(&mut self, child_operation: &Operation, child_grad: f64, operand2: &Expr) {
385        match child_operation {
386            Operation::Add => {
387                self.grad = child_grad;
388            }
389            Operation::Sub => {
390                self.grad = child_grad;
391            }
392            Operation::Mul => {
393                let operand2_result = operand2.result;
394
395                self.grad = child_grad * operand2_result;
396            }
397            Operation::Div => {
398                let operand2_result = operand2.result;
399
400                self.grad = child_grad / operand2_result;
401            }
402            Operation::Pow => {
403                let exponent = operand2.result;
404                let base = self.result;
405
406                self.grad = child_grad * exponent * base.powf(exponent - 1.0);
407            }
408            _ => panic!("Invalid binary operation: {:?}", child_operation),
409        }
410    }
411
412    fn adjust_grad_binary_op2(&mut self,child_operation: &Operation, child_grad: f64, operand1: &Expr) {
413        match child_operation {
414            Operation::Add => {
415                self.grad = child_grad;
416            }
417            Operation::Sub => {
418                self.grad = -child_grad;
419            }
420            Operation::Mul => {
421                let operand1_result = operand1.result;
422                self.grad = child_grad * operand1_result;
423            }
424            Operation::Div => {
425                let operand2_result = self.result;
426                let operand1_result = operand1.result;
427
428                self.grad = -child_grad * operand1_result / (operand2_result * operand2_result);
429            }
430            Operation::Pow => {
431                let exponent = self.result;
432                let base = operand1.result;
433
434                self.grad = child_grad * base.powf(exponent) * base.ln();
435            }
436            _ => panic!("Invalid binary operation: {:?}", child_operation),
437        }
438    }
439
440    /// Finds a node in the expression tree by its name.
441    /// 
442    /// This method will search the expression tree for a node with the given name.
443    /// If the node is not found, it will return [None].
444    /// 
445    /// Example:
446    /// ```rust
447    /// use alpha_micrograd_rust::value::Expr;
448    /// 
449    /// let expr = Expr::new_leaf_with_name(1.0, "x");
450    /// let expr2 = expr.tanh();
451    /// let original = expr2.find("x");
452    /// 
453    /// assert_eq!(original.expect("Could not find x").result, 1.0);
454    /// ```
455    pub fn find(&self, name: &str) -> Option<&Expr> {
456        let mut stack = vec![self];
457
458        while let Some(node) = stack.pop() {
459            if node.name == Some(name.to_string()) {
460                return Some(node);
461            }
462
463            if let Some(operand1) = node.operand1.as_ref() {
464                stack.push(operand1);
465            }
466            if let Some(operand2) = node.operand2.as_ref() {
467                stack.push(operand2);
468            }
469        }
470
471        None
472    }
473
474    /// Finds a node in the expression tree by its name and returns a mutable reference to it.
475    /// 
476    /// This method will search the expression tree for a node with the given name.
477    /// If the node is not found, it will return [None].
478    /// 
479    /// Example:
480    /// ```rust
481    /// use alpha_micrograd_rust::value::Expr;
482    /// 
483    /// let expr = Expr::new_leaf_with_name(1.0, "x");
484    /// let mut expr2 = expr.tanh();
485    /// let mut original = expr2.find_mut("x").expect("Could not find x");
486    /// original.result = 2.0;
487    /// expr2.recalculate();
488    /// 
489    /// assert_eq!(expr2.result, 0.9640275800758169);
490    /// ```
491    pub fn find_mut(&mut self, name: &str) -> Option<&mut Expr> {
492        let mut stack = vec![self];
493
494        while let Some(node) = stack.pop() {
495            if node.name == Some(name.to_string()) {
496                return Some(node);
497            }
498
499            if let Some(operand1) = node.operand1.as_mut() {
500                stack.push(operand1);
501            }
502            if let Some(operand2) = node.operand2.as_mut() {
503                stack.push(operand2);
504            }
505        }
506
507        None
508    }
509
510    /// Returns the count of nodes (parameters)in the expression tree.
511    /// 
512    /// This method will return the total number of nodes in the expression tree,
513    /// including the root node.
514    /// 
515    /// Example:
516    /// ```rust
517    /// use alpha_micrograd_rust::value::Expr;
518    /// 
519    /// let expr = Expr::new_leaf(1.0);
520    /// let expr2 = expr.tanh();
521    /// 
522    /// assert_eq!(expr2.parameter_count(false), 2);
523    /// assert_eq!(expr2.parameter_count(true), 1);
524    /// ```
525    pub fn parameter_count(&self, learnable_only: bool) -> usize {
526        let mut stack = vec![self];
527        let mut count = 0;
528
529        while let Some(node) = stack.pop() {
530            if node.is_learnable || !learnable_only {
531                count += 1;
532            }
533
534            if let Some(operand1) = node.operand1.as_ref() {
535                stack.push(operand1);
536            }
537            if let Some(operand2) = node.operand2.as_ref() {
538                stack.push(operand2);
539            }
540        }
541
542        count
543    }
544}
545
546/// Implements the [`Add`] trait for the [`Expr`] struct.
547/// 
548/// This implementation allows the addition of two [`Expr`] objects.
549/// 
550/// Example:
551/// ```rust
552/// use alpha_micrograd_rust::value::Expr;
553/// 
554/// let expr = Expr::new_leaf(1.0);
555/// let expr2 = Expr::new_leaf(2.0);
556/// 
557/// let result = expr + expr2;
558/// 
559/// assert_eq!(result.result, 3.0);
560/// ```
561impl Add for Expr {
562    type Output = Expr;
563
564    fn add(self, other: Expr) -> Expr {
565        let result = self.result + other.result;
566        Expr::new_binary(self, other, Operation::Add, result)
567    }
568}
569
570/// Implements the [`Add`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
571/// 
572/// This implementation allows the addition of an [`Expr`] object and a [`f64`] value.
573/// 
574/// Example:
575/// ```rust
576/// use alpha_micrograd_rust::value::Expr;
577/// 
578/// let expr = Expr::new_leaf(1.0);
579/// let result = expr + 2.0;
580/// 
581/// assert_eq!(result.result, 3.0);
582/// ```
583impl Add<f64> for Expr {
584    type Output = Expr;
585
586    fn add(self, other: f64) -> Expr {
587        let operand2 = Expr::new_leaf(other);
588        self + operand2
589    }
590}
591
592/// Implements the [`Add`] trait for the [`f64`] type, when the right operand is an [`Expr`].
593/// 
594/// This implementation allows the addition of a [`f64`] value and an [`Expr`] object.
595/// 
596/// Example:
597/// ```rust
598/// use alpha_micrograd_rust::value::Expr;
599/// 
600/// let expr = Expr::new_leaf(1.0);
601/// let result = 2.0 + expr;
602/// 
603/// assert_eq!(result.result, 3.0);
604/// ```
605impl Add<Expr> for f64 {
606    type Output = Expr;
607
608    fn add(self, other: Expr) -> Expr {
609        let operand1 = Expr::new_leaf(self);
610        operand1 + other
611    }
612}
613
614/// Implements the [`Mul`] trait for the [`Expr`] struct.
615/// 
616/// This implementation allows the multiplication of two [`Expr`] objects.
617/// 
618/// Example:
619/// 
620/// ```rust
621/// use alpha_micrograd_rust::value::Expr;
622/// 
623/// let expr = Expr::new_leaf(1.0);
624/// let expr2 = Expr::new_leaf(2.0);
625/// 
626/// let result = expr * expr2;
627/// 
628/// assert_eq!(result.result, 2.0);
629/// ```
630impl Mul for Expr {
631    type Output = Expr;
632
633    fn mul(self, other: Expr) -> Expr {
634        let result = self.result * other.result;
635        Expr::new_binary(self, other, Operation::Mul, result)
636    }
637}
638
639/// Implements the [`Mul`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
640/// 
641/// This implementation allows the multiplication of an [`Expr`] object and a [`f64`] value.
642/// 
643/// Example:
644/// 
645/// ```rust
646/// use alpha_micrograd_rust::value::Expr;
647/// 
648/// let expr = Expr::new_leaf(1.0);
649/// let result = expr * 2.0;
650/// 
651/// assert_eq!(result.result, 2.0);
652/// ```
653impl Mul<f64> for Expr {
654    type Output = Expr;
655
656    fn mul(self, other: f64) -> Expr {
657        let operand2 = Expr::new_leaf(other);
658        self * operand2
659    }
660}
661
662/// Implements the [`Mul`] trait for the [`f64`] type, when the right operand is an [`Expr`].
663/// 
664/// This implementation allows the multiplication of a [`f64`] value and an [`Expr`] object.
665/// 
666/// Example:
667/// 
668/// ```rust
669/// use alpha_micrograd_rust::value::Expr;
670/// 
671/// let expr = Expr::new_leaf(1.0);
672/// let result = 2.0 * expr;
673/// 
674/// assert_eq!(result.result, 2.0);
675/// ```
676impl Mul<Expr> for f64 {
677    type Output = Expr;
678
679    fn mul(self, other: Expr) -> Expr {
680        let operand1 = Expr::new_leaf(self);
681        operand1 * other
682    }
683}
684
685/// Implements the [`Sub`] trait for the [`Expr`] struct.
686/// 
687/// This implementation allows the subtraction of two [`Expr`] objects.
688/// 
689/// Example:
690/// 
691/// ```rust
692/// use alpha_micrograd_rust::value::Expr;
693/// 
694/// let expr = Expr::new_leaf(1.0);
695/// let expr2 = Expr::new_leaf(2.0);
696/// 
697/// let result = expr - expr2;
698/// 
699/// assert_eq!(result.result, -1.0);
700/// ```
701impl Sub for Expr {
702    type Output = Expr;
703
704    fn sub(self, other: Expr) -> Expr {
705        let result = self.result - other.result;
706        Expr::new_binary(self, other, Operation::Sub, result)
707    }
708}
709
710/// Implements the [`Sub`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
711/// 
712/// This implementation allows the subtraction of an [`Expr`] object and a [`f64`] value.
713/// 
714/// Example:
715/// 
716/// ```rust
717/// use alpha_micrograd_rust::value::Expr;
718/// 
719/// let expr = Expr::new_leaf(1.0);
720/// let result = expr - 2.0;
721/// 
722/// assert_eq!(result.result, -1.0);
723/// ```
724impl Sub<f64> for Expr {
725    type Output = Expr;
726
727    fn sub(self, other: f64) -> Expr {
728        let operand2 = Expr::new_leaf(other);
729        self - operand2
730    }
731}
732
733/// Implements the [`Sub`] trait for the [`f64`] type, when the right operand is an [`Expr`].
734/// 
735/// This implementation allows the subtraction of a [`f64`] value and an [`Expr`] object.
736/// 
737/// Example:
738/// 
739/// ```rust
740/// use alpha_micrograd_rust::value::Expr;
741/// 
742/// let expr = Expr::new_leaf(1.0);
743/// let result = 2.0 - expr;
744/// 
745/// assert_eq!(result.result, 1.0);
746/// ```
747impl Sub<Expr> for f64 {
748    type Output = Expr;
749
750    fn sub(self, other: Expr) -> Expr {
751        let operand1 = Expr::new_leaf(self);
752        operand1 - other
753    }
754}
755
756/// Implements the [`Div`] trait for the [`Expr`] struct.
757/// 
758/// This implementation allows the division of two [`Expr`] objects.
759/// 
760/// Example:
761/// 
762/// ```rust
763/// use alpha_micrograd_rust::value::Expr;
764/// 
765/// let expr = Expr::new_leaf(1.0);
766/// let expr2 = Expr::new_leaf(2.0);
767/// 
768/// let result = expr / expr2;
769/// 
770/// assert_eq!(result.result, 0.5);
771/// ```
772impl Div for Expr {
773    type Output = Expr;
774
775    fn div(self, other: Expr) -> Expr {
776        let result = self.result / other.result;
777        Expr::new_binary(self, other, Operation::Div, result)
778    }
779}
780
781/// Implements the [`Div`] trait for the [`Expr`] struct, when the right operand is a [`f64`].
782/// 
783/// This implementation allows the division of an [`Expr`] object and a [`f64`] value.
784/// 
785/// Example:
786/// 
787/// ```rust
788/// use alpha_micrograd_rust::value::Expr;
789/// 
790/// let expr = Expr::new_leaf(1.0);
791/// let result = expr / 2.0;
792/// 
793/// assert_eq!(result.result, 0.5);
794/// ```
795impl Div<f64> for Expr {
796    type Output = Expr;
797
798    fn div(self, other: f64) -> Expr {
799        let operand2 = Expr::new_leaf(other);
800        self / operand2
801    }
802}
803
804/// Implements the [`Sum`] trait for the [`Expr`] struct.
805/// 
806/// Note that this implementation will generate temporary [`Expr`] objects,
807/// which may not be the most efficient way to sum a collection of [`Expr`] objects.
808/// However, it is provided as a convenience method for users that want to use sum
809/// over an [`Iterator<Expr>`].
810/// 
811/// Example:
812/// 
813/// ```rust
814/// use alpha_micrograd_rust::value::Expr;
815/// 
816/// let expr = Expr::new_leaf(1.0);
817/// let expr2 = Expr::new_leaf(2.0);
818/// let expr3 = Expr::new_leaf(3.0);
819/// 
820/// let sum = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
821/// 
822/// assert_eq!(sum.result, 6.0);
823/// ```
824impl Sum for Expr {
825    fn sum<I>(iter: I) -> Self
826    where
827        I: Iterator<Item = Self>,
828    {
829        iter.reduce(|acc, x| acc + x)
830            .unwrap_or(Expr::new_leaf(0.0))
831    }
832}
833
834#[cfg(test)]
835mod tests {
836    use super::*;
837
838    fn assert_float_eq(f1: f64, f2: f64) {
839        let places = 7;
840        let tolerance = 10.0_f64.powi(-places);
841        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
842    }
843
844    #[test]
845    fn test() {
846        let expr = Expr::new_leaf(1.0);
847        assert_eq!(expr.result, 1.0);
848    }
849
850    #[test]
851    fn test_unary() {
852        let expr = Expr::new_leaf(1.0);
853        let expr2 = Expr::new_unary(expr, Operation::Tanh, 1.1);
854
855        assert_eq!(expr2.result, 1.1);
856        assert_eq!(expr2.operand1.unwrap().result, 1.0);
857    }
858
859    #[test]
860    #[should_panic]
861    fn test_unary_expression_type_check() {
862        let expr = Expr::new_leaf(1.0);
863        let _expr2 = Expr::new_unary(expr, Operation::Add, 1.1);
864    }
865
866    #[test]
867    fn test_binary() {
868        let expr = Expr::new_leaf(1.0);
869        let expr2 = Expr::new_leaf(2.0);
870        let expr3 = Expr::new_binary(expr, expr2, Operation::Add, 1.1);
871
872        assert_eq!(expr3.result, 1.1);
873        assert_eq!(expr3.operand1.unwrap().result, 1.0);
874        assert_eq!(expr3.operand2.unwrap().result, 2.0);
875    }
876
877    #[test]
878    #[should_panic]
879    fn test_binary_expression_type_check() {
880        let expr = Expr::new_leaf(1.0);
881        let expr2 = Expr::new_leaf(2.0);
882        let _expr3 = Expr::new_binary(expr, expr2, Operation::Tanh, 3.0);
883    }
884
885    #[test]
886    fn test_mixed_tree() {
887        let expr = Expr::new_leaf(1.0);
888        let expr2 = Expr::new_leaf(2.0);
889        let expr3 = Expr::new_binary(expr, expr2, Operation::Sub, 1.1);
890        let expr4 = Expr::new_unary(expr3, Operation::Tanh, 1.2);
891
892        assert_eq!(expr4.result, 1.2);
893        let expr3 = expr4.operand1.unwrap();
894        assert_eq!(expr3.result, 1.1);
895        assert_eq!(expr3.operand1.unwrap().result, 1.0);
896        assert_eq!(expr3.operand2.unwrap().result, 2.0);
897    }
898
899    #[test]
900    fn test_tanh() {
901        let expr = Expr::new_leaf(1.0);
902        let expr2 = expr.tanh();
903
904        assert_eq!(expr2.result, 0.7615941559557649);
905        assert!(expr2.operand1.is_some());
906        assert_eq!(expr2.operand1.unwrap().result, 1.0);
907        assert_eq!(expr2.operation, Operation::Tanh);
908        assert!(expr2.operand2.is_none());
909
910        // Some other known values
911        fn get_tanh(x: f64) -> f64 {
912            Expr::new_leaf(x).tanh().result
913        }
914
915        assert_float_eq(get_tanh(10.74), 0.9999999);
916        assert_float_eq(get_tanh(-10.74), -0.9999999);
917        assert_float_eq(get_tanh(0.0), 0.0);
918    }
919
920    #[test]
921    fn test_exp() {
922        let expr = Expr::new_leaf(1.0);
923        let expr2 = expr.exp();
924
925        assert_eq!(expr2.result, 2.718281828459045);
926        assert!(expr2.operand1.is_some());
927        assert_eq!(expr2.operand1.unwrap().result, 1.0);
928        assert_eq!(expr2.operation, Operation::Exp);
929        assert!(expr2.operand2.is_none());
930    }
931
932    #[test]
933    fn test_relu() {
934        // negative case
935        let expr = Expr::new_leaf(-1.0);
936        let expr2 = expr.relu();
937
938        assert_eq!(expr2.result, 0.0);
939        assert!(expr2.operand1.is_some());
940        assert_eq!(expr2.operand1.unwrap().result, -1.0);
941        assert_eq!(expr2.operation, Operation::ReLU);
942        assert!(expr2.operand2.is_none());
943
944        // positive case
945        let expr = Expr::new_leaf(1.0);
946        let expr2 = expr.relu();
947
948        assert_eq!(expr2.result, 1.0);
949        assert!(expr2.operand1.is_some());
950        assert_eq!(expr2.operand1.unwrap().result, 1.0);
951        assert_eq!(expr2.operation, Operation::ReLU);
952        assert!(expr2.operand2.is_none());
953    }
954
955    #[test]
956    fn test_pow() {
957        let expr = Expr::new_leaf(2.0);
958        let expr2 = Expr::new_leaf(3.0);
959        let result = expr.pow(expr2);
960
961        assert_eq!(result.result, 8.0);
962        assert!(result.operand1.is_some());
963        assert_eq!(result.operand1.unwrap().result, 2.0);
964        assert_eq!(result.operation, Operation::Pow);
965        
966        assert!(result.operand2.is_some());
967        assert_eq!(result.operand2.unwrap().result, 3.0);
968    }
969
970    #[test]
971    fn test_add() {
972        let expr = Expr::new_leaf(1.0);
973        let expr2 = Expr::new_leaf(2.0);
974        let expr3 = expr + expr2;
975
976        assert_eq!(expr3.result, 3.0);
977        assert!(expr3.operand1.is_some());
978        assert_eq!(expr3.operand1.unwrap().result, 1.0);
979        assert!(expr3.operand2.is_some());
980        assert_eq!(expr3.operand2.unwrap().result, 2.0);
981        assert_eq!(expr3.operation, Operation::Add);
982    }
983
984    #[test]
985    fn test_add_f64() {
986        let expr = Expr::new_leaf(1.0);
987        let expr2 = expr + 2.0;
988
989        assert_eq!(expr2.result, 3.0);
990        assert!(expr2.operand1.is_some());
991        assert_eq!(expr2.operand1.unwrap().result, 1.0);
992        assert!(expr2.operand2.is_some());
993        assert_eq!(expr2.operand2.unwrap().result, 2.0);
994        assert_eq!(expr2.operation, Operation::Add);
995    }
996
997    #[test]
998    fn test_add_f64_expr() {
999        let expr = Expr::new_leaf(1.0);
1000        let expr2 = 2.0 + expr;
1001
1002        assert_eq!(expr2.result, 3.0);
1003        assert!(expr2.operand1.is_some());
1004        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1005        assert!(expr2.operand2.is_some());
1006        assert_eq!(expr2.operand2.unwrap().result, 1.0);
1007        assert_eq!(expr2.operation, Operation::Add);
1008    }
1009
1010    #[test]
1011    fn test_mul() {
1012        let expr = Expr::new_leaf(2.0);
1013        let expr2 = Expr::new_leaf(3.0);
1014        let expr3 = expr * expr2;
1015
1016        assert_eq!(expr3.result, 6.0);
1017        assert!(expr3.operand1.is_some());
1018        assert_eq!(expr3.operand1.unwrap().result, 2.0);
1019        assert!(expr3.operand2.is_some());
1020        assert_eq!(expr3.operand2.unwrap().result, 3.0);
1021        assert_eq!(expr3.operation, Operation::Mul);
1022    }
1023
1024    #[test]
1025    fn test_mul_f64() {
1026        let expr = Expr::new_leaf(2.0);
1027        let expr2 = expr * 3.0;
1028
1029        assert_eq!(expr2.result, 6.0);
1030        assert!(expr2.operand1.is_some());
1031        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1032        assert!(expr2.operand2.is_some());
1033        assert_eq!(expr2.operand2.unwrap().result, 3.0);
1034        assert_eq!(expr2.operation, Operation::Mul);
1035    }
1036
1037    #[test]
1038    fn test_mul_f64_expr() {
1039        let expr = Expr::new_leaf(2.0);
1040        let expr2 = 3.0 * expr;
1041
1042        assert_eq!(expr2.result, 6.0);
1043        assert!(expr2.operand1.is_some());
1044        assert_eq!(expr2.operand1.unwrap().result, 3.0);
1045        assert!(expr2.operand2.is_some());
1046        assert_eq!(expr2.operand2.unwrap().result, 2.0);
1047        assert_eq!(expr2.operation, Operation::Mul);
1048    }
1049
1050    #[test]
1051    fn test_sub() {
1052        let expr = Expr::new_leaf(2.0);
1053        let expr2 = Expr::new_leaf(3.0);
1054        let expr3 = expr - expr2;
1055
1056        assert_eq!(expr3.result, -1.0);
1057        assert!(expr3.operand1.is_some());
1058        assert_eq!(expr3.operand1.unwrap().result, 2.0);
1059        assert!(expr3.operand2.is_some());
1060        assert_eq!(expr3.operand2.unwrap().result, 3.0);
1061        assert_eq!(expr3.operation, Operation::Sub);
1062    }
1063
1064    #[test]
1065    fn test_sub_f64() {
1066        let expr = Expr::new_leaf(2.0);
1067        let expr2 = expr - 3.0;
1068
1069        assert_eq!(expr2.result, -1.0);
1070        assert!(expr2.operand1.is_some());
1071        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1072        assert!(expr2.operand2.is_some());
1073        assert_eq!(expr2.operand2.unwrap().result, 3.0);
1074        assert_eq!(expr2.operation, Operation::Sub);
1075    }
1076
1077    #[test]
1078    fn test_sub_f64_expr() {
1079        let expr = Expr::new_leaf(2.0);
1080        let expr2 = 3.0 - expr;
1081
1082        assert_eq!(expr2.result, 1.0);
1083        assert!(expr2.operand1.is_some());
1084        assert_eq!(expr2.operand1.unwrap().result, 3.0);
1085        assert!(expr2.operand2.is_some());
1086        assert_eq!(expr2.operand2.unwrap().result, 2.0);
1087        assert_eq!(expr2.operation, Operation::Sub);
1088    }
1089
1090    #[test]
1091    fn test_div() {
1092        let expr = Expr::new_leaf(6.0);
1093        let expr2 = Expr::new_leaf(3.0);
1094        let expr3 = expr / expr2;
1095
1096        assert_eq!(expr3.result, 2.0);
1097        assert!(expr3.operand1.is_some());
1098        assert_eq!(expr3.operand1.unwrap().result, 6.0);
1099        assert!(expr3.operand2.is_some());
1100        assert_eq!(expr3.operand2.unwrap().result, 3.0);
1101        assert_eq!(expr3.operation, Operation::Div);
1102    }
1103
1104    #[test]
1105    fn test_div_f64() {
1106        let expr = Expr::new_leaf(6.0);
1107        let expr2 = expr / 3.0;
1108
1109        assert_eq!(expr2.result, 2.0);
1110        assert!(expr2.operand1.is_some());
1111        assert_eq!(expr2.operand1.unwrap().result, 6.0);
1112        assert!(expr2.operand2.is_some());
1113        assert_eq!(expr2.operand2.unwrap().result, 3.0);
1114        assert_eq!(expr2.operation, Operation::Div);
1115    }
1116
1117    #[test]
1118    fn test_log() {
1119        let expr = Expr::new_leaf(2.0);
1120        let expr2 = expr.log();
1121
1122        assert_eq!(expr2.result, 0.6931471805599453);
1123        assert!(expr2.operand1.is_some());
1124        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1125        assert_eq!(expr2.operation, Operation::Log);
1126        assert!(expr2.operand2.is_none());
1127    }
1128
1129    #[test]
1130    fn test_neg() {
1131        let expr = Expr::new_leaf(2.0);
1132        let expr2 = expr.neg();
1133
1134        assert_eq!(expr2.result, -2.0);
1135        assert!(expr2.operand1.is_some());
1136        assert_eq!(expr2.operand1.unwrap().result, 2.0);
1137        assert_eq!(expr2.operation, Operation::Neg);
1138        assert!(expr2.operand2.is_none());
1139    }
1140
1141    #[test]
1142    fn test_backpropagation_add() {
1143        let operand1 = Expr::new_leaf(1.0);
1144        let operand2 = Expr::new_leaf(2.0);
1145        let mut expr3 = operand1 + operand2;
1146
1147        expr3.learn(1e-09);
1148
1149        let operand1 = expr3.operand1.unwrap();
1150        let operand2 = expr3.operand2.unwrap();
1151        assert_eq!(operand1.grad, 1.0);
1152        assert_eq!(operand2.grad, 1.0);
1153    }
1154
1155    #[test]
1156    fn test_backpropagation_sub() {
1157        let operand1 = Expr::new_leaf(1.0);
1158        let operand2 = Expr::new_leaf(2.0);
1159        let mut expr3 = operand1 - operand2;
1160
1161        expr3.learn(1e-09);
1162
1163        let operand1 = expr3.operand1.unwrap();
1164        let operand2 = expr3.operand2.unwrap();
1165        assert_eq!(operand1.grad, 1.0);
1166        assert_eq!(operand2.grad, -1.0);
1167    }
1168
1169    #[test]
1170    fn test_backpropagation_mul() {
1171        let operand1 = Expr::new_leaf(3.0);
1172        let operand2 = Expr::new_leaf(4.0);
1173        let mut expr3 = operand1 * operand2;
1174
1175        expr3.learn(1e-09);
1176
1177        let operand1 = expr3.operand1.unwrap();
1178        let operand2 = expr3.operand2.unwrap();
1179        assert_eq!(operand1.grad, 4.0);
1180        assert_eq!(operand2.grad, 3.0);
1181    }
1182
1183    #[test]
1184    fn test_backpropagation_div() {
1185        let operand1 = Expr::new_leaf(3.0);
1186        let operand2 = Expr::new_leaf(4.0);
1187        let mut expr3 = operand1 / operand2;
1188
1189        expr3.learn(1e-09);
1190
1191        let operand1 = expr3.operand1.unwrap();
1192        let operand2 = expr3.operand2.unwrap();
1193        assert_eq!(operand1.grad, 0.25);
1194        assert_eq!(operand2.grad, -0.1875);
1195    }
1196
1197    #[test]
1198    fn test_backpropagation_tanh() {
1199        let operand1 = Expr::new_leaf(0.0);
1200        let mut expr2 = operand1.tanh();
1201
1202        expr2.learn(1e-09);
1203
1204        let operand1 = expr2.operand1.unwrap();
1205        assert_eq!(operand1.grad, 1.0);
1206    }
1207
1208    #[test]
1209    fn test_backpropagation_relu() {
1210        let operand1 = Expr::new_leaf(-1.0);
1211        let mut expr2 = operand1.relu();
1212
1213        expr2.learn(1e-09);
1214
1215        let operand1 = expr2.operand1.unwrap();
1216        assert_eq!(operand1.grad, 0.0);
1217    }
1218
1219    #[test]
1220    fn test_backpropagation_exp() {
1221        let operand1 = Expr::new_leaf(0.0);
1222        let mut expr2 = operand1.exp();
1223
1224        expr2.learn(1e-09);
1225
1226        let operand1 = expr2.operand1.unwrap();
1227        assert_eq!(operand1.grad, 1.0);
1228    }
1229
1230    #[test]
1231    fn test_backpropagation_pow() {
1232        let operand1 = Expr::new_leaf(2.0);
1233        let operand2 = Expr::new_leaf(3.0);
1234        let mut expr3 = operand1.pow(operand2);
1235
1236        expr3.learn(1e-09);
1237
1238        let operand1 = expr3.operand1.unwrap();
1239        let operand2 = expr3.operand2.unwrap();
1240        assert_eq!(operand1.grad, 12.0);
1241        assert_eq!(operand2.grad, 5.545177444479562);
1242    }
1243
1244    #[test]
1245    fn test_backpropagation_mixed_tree() {
1246        let operand1 = Expr::new_leaf(1.0);
1247        let operand2 = Expr::new_leaf(2.0);
1248        let expr3 = operand1 + operand2;
1249        let mut expr4 = expr3.tanh();
1250
1251        expr4.learn(1e-09);
1252
1253        let expr3 = expr4.operand1.unwrap();
1254        let operand1 = expr3.operand1.unwrap();
1255        let operand2 = expr3.operand2.unwrap();
1256
1257        assert_eq!(expr3.grad, 0.009866037165440211);
1258        assert_eq!(operand1.grad, 0.009866037165440211);
1259        assert_eq!(operand2.grad, 0.009866037165440211);
1260    }
1261
1262    #[test]
1263    fn test_backpropagation_karpathys_example() {
1264        let x1 = Expr::new_leaf(2.0);
1265        let x2 = Expr::new_leaf(0.0);
1266        let w1 = Expr::new_leaf(-3.0);
1267        let w2 = Expr::new_leaf(1.0);
1268        let b = Expr::new_leaf(6.8813735870195432);
1269
1270        let x1w1 = x1 * w1;
1271        let x2w2 = x2 * w2;
1272        let x1w1_x2w2 = x1w1 + x2w2;
1273        let n = x1w1_x2w2 + b;
1274        let mut o = n.tanh();
1275
1276        o.learn(1e-09);
1277
1278        assert_eq!(o.operation, Operation::Tanh);
1279        assert_eq!(o.grad, 1.0);
1280
1281        let n = o.operand1.unwrap();
1282        assert_eq!(n.operation, Operation::Add);
1283        assert_float_eq(n.grad, 0.5);
1284
1285        let x1w1_x2w2 = n.operand1.unwrap();
1286        assert_eq!(x1w1_x2w2.operation, Operation::Add);
1287        assert_float_eq(x1w1_x2w2.grad, 0.5);
1288
1289        let b = n.operand2.unwrap();
1290        assert_eq!(b.operation, Operation::None);
1291        assert_float_eq(b.grad, 0.5);
1292
1293        let x1w1 = x1w1_x2w2.operand1.unwrap();
1294        assert_eq!(x1w1.operation, Operation::Mul);
1295        assert_float_eq(x1w1.grad, 0.5);
1296
1297        let x2w2 = x1w1_x2w2.operand2.unwrap();
1298        assert_eq!(x2w2.operation, Operation::Mul);
1299        assert_float_eq(x2w2.grad, 0.5);
1300
1301        let x1 = x1w1.operand1.unwrap();
1302        assert_eq!(x1.operation, Operation::None);
1303        assert_float_eq(x1.grad, -1.5);
1304
1305        let w1 = x1w1.operand2.unwrap();
1306        assert_eq!(w1.operation, Operation::None);
1307        assert_float_eq(w1.grad, 1.0);
1308
1309        let x2 = x2w2.operand1.unwrap();
1310        assert_eq!(x2.operation, Operation::None);
1311        assert_float_eq(x2.grad, 0.5);
1312
1313        let w2 = x2w2.operand2.unwrap();
1314        assert_eq!(w2.operation, Operation::None);
1315        assert_float_eq(w2.grad, 0.0);
1316    }
1317
1318    #[test]
1319    fn test_learn_simple() {
1320        let mut expr = Expr::new_leaf(1.0);
1321        expr.learn(1e-01);
1322
1323        assert_float_eq(expr.result, 0.9);
1324    }
1325
1326    #[test]
1327    fn test_learn_skips_non_learnable() {
1328        let mut expr = Expr::new_leaf(1.0);
1329        expr.is_learnable = false;
1330        expr.learn(1e-01);
1331
1332        assert_float_eq(expr.result, 1.0);
1333    }
1334
1335    #[test]
1336    fn test_find_simple() {
1337        let expr = Expr::new_leaf_with_name(1.0, "x");
1338        let expr2 = expr.tanh();
1339
1340        let found = expr2.find("x");
1341        assert!(found.is_some());
1342        assert_eq!(found.unwrap().name, Some("x".to_string()));
1343    }
1344
1345    #[test]
1346    fn test_find_not_found() {
1347        let expr = Expr::new_leaf_with_name(1.0, "x");
1348        let expr2 = expr.tanh();
1349
1350        let found = expr2.find("y");
1351        assert!(found.is_none());
1352    }
1353
1354    #[test]
1355    fn test_sum_iterator() {
1356        let expr = Expr::new_leaf(1.0);
1357        let expr2 = Expr::new_leaf(2.0);
1358        let expr3 = Expr::new_leaf(3.0);
1359
1360        let sum: Expr = vec![expr, expr2, expr3].into_iter().sum::<Expr>();
1361        assert_eq!(sum.result, 6.0);
1362    }
1363
1364    #[test]
1365    fn test_find_after_clone() {
1366        let expr = Expr::new_leaf_with_name(1.0, "x");
1367        let expr2 = expr.tanh();
1368        let expr2_clone = expr2.clone();
1369
1370        let found = expr2_clone.find("x");
1371        assert!(found.is_some());
1372        assert_eq!(found.unwrap().name, Some("x".to_string()));
1373    }
1374}