alpha_micrograd_rust/
compiled.rs

//! A compiled version of the expression tree for faster computation.
//!
//! This module provides a compiled version of the expression tree, which is faster to compute
//! and backpropagate through than the regular expression tree, at the expense of not being
//! able to modify it further.
//!
//! This module contains the following elements:
//!
//! - [`CompiledExpr`]: A struct that represents a compiled version of an expression tree.
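//!
//! Example (a minimal end-to-end sketch, using only the API described below):
//!
//! ```rust
//! use alpha_micrograd_rust::value::Expr;
//! use alpha_micrograd_rust::compiled::CompiledExpr;
//!
//! // Build a small expression tree and compile it once.
//! let x = Expr::new_leaf_with_name(1.0, "x");
//! let y = x.tanh();
//! let mut compiled = CompiledExpr::from_expr(y);
//!
//! // Change a named parameter and re-run the forward pass on the compiled tape.
//! compiled.set("x", 2.0);
//! compiled.recalculate();
//! assert_eq!(compiled.result(), 0.9640275800758169);
//!
//! // Take one backpropagation / learning step.
//! compiled.learn(1e-09);
//! ```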
#![deny(missing_docs)]
use std::collections::HashMap;

use crate::value::{Expr, Operation};

/// A compiled version of an expression tree.
///
/// This struct represents a compiled version of an expression tree, which is faster to compute
/// and backpropagate through than the regular expression tree. Internally, the tree is
/// flattened into parallel vectors (a tape) in post-order, so that every node appears
/// after its operands.
pub struct CompiledExpr {
    operations: Vec<Operation>,
    lhs: Vec<Option<usize>>,
    rhs: Vec<Option<usize>>,
    /// The results of the operations in the expression.
    pub results: Vec<f64>,
    /// The gradients of the operations in the expression.
    pub gradients: Vec<f64>,
    is_learnable: Vec<bool>,
    names_to_index: HashMap<String, usize>,
}

impl CompiledExpr {
    // Recursively consumes an expression tree, appending its nodes to the tape in
    // post-order (operands first, then the operation itself), so that each node's
    // index is greater than the indices of its operands.
    fn consume_expr(&mut self, expr: Expr) {
        let lhs = if let Some(operand1) = expr.operand1 {
            self.consume_expr(*operand1);
            Some(self.results.len() - 1)
        } else {
            None
        };

        let rhs = if let Some(operand2) = expr.operand2 {
            self.consume_expr(*operand2);
            Some(self.results.len() - 1)
        } else {
            None
        };

        self.lhs.push(lhs);
        self.rhs.push(rhs);
        self.results.push(expr.result);
        self.operations.push(expr.operation);
        self.gradients.push(expr.grad);
        self.is_learnable.push(expr.is_learnable);
        if let Some(name) = expr.name {
            self.names_to_index.insert(name, self.results.len() - 1);
        }
    }

    /// Creates a new `CompiledExpr` from an expression.
    ///
    /// This method consumes the expression and transforms it into a compiled form
    /// that is more efficient for computation and backpropagation.
    ///
    /// Example:
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let expr = Expr::new_leaf(1.0);
    /// let expr2 = expr.tanh();
    /// let compiled = CompiledExpr::from_expr(expr2);
    /// ```
    pub fn from_expr(expr: Expr) -> Self {
        let parameter_count = expr.parameter_count(false);
        let mut tape = CompiledExpr {
            operations: Vec::with_capacity(parameter_count),
            lhs: Vec::with_capacity(parameter_count),
            rhs: Vec::with_capacity(parameter_count),
            results: Vec::with_capacity(parameter_count),
            gradients: Vec::with_capacity(parameter_count),
            is_learnable: Vec::with_capacity(parameter_count),
            names_to_index: HashMap::new(),
        };

        tape.consume_expr(expr);

        tape
    }

    /// Recalculates the expression based on the current values.
    ///
    /// This method recalculates the expression based on the current values of the parameters.
    /// It is more efficient than recalculating the expression tree, as it iterates through
    /// an array of operations instead of traversing a tree structure.
    ///
    /// Example:
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let expr = Expr::new_leaf_with_name(1.0, "x");
    /// let expr2 = expr.tanh();
    /// let mut compiled = CompiledExpr::from_expr(expr2);
    /// assert_eq!(compiled.result(), 0.7615941559557649);
    ///
    /// // Modify the value of "x"
    /// compiled.set("x", 2.0);
    /// compiled.recalculate();
    ///
    /// assert_eq!(compiled.result(), 0.9640275800758169);
    /// ```
    pub fn recalculate(&mut self) {
        for i in 0..self.results.len() {
            let operation = self.operations[i];
            let lhs_index = self.lhs[i];
            let rhs_index = self.rhs[i];

            let lhs_value = if let Some(index) = lhs_index {
                self.results[index]
            } else {
                0.0 // Default value for leaf nodes
            };

            let rhs_value = if let Some(index) = rhs_index {
                self.results[index]
            } else {
                0.0 // Default value for leaf nodes
            };

            self.results[i] = match operation {
                Operation::Add => lhs_value + rhs_value,
                Operation::Sub => lhs_value - rhs_value,
                Operation::Mul => lhs_value * rhs_value,
                Operation::Div => lhs_value / rhs_value,
                Operation::None => self.results[i], // No operation, keep the value
                Operation::Tanh => lhs_value.tanh(),
                Operation::Exp => lhs_value.exp(),
                Operation::Pow => lhs_value.powf(rhs_value),
                Operation::Log => lhs_value.ln(),
                Operation::ReLU => lhs_value.max(0.0),
                Operation::Neg => -lhs_value,
            };
        }
    }

    /// Performs one step of learning (backpropagation) on the compiled expression.
    ///
    /// This method propagates the gradient of the last expression in the calculation graph
    /// backwards through the tape, updating the gradients of all intermediate nodes. It then
    /// updates the values of the learnable parameters (leaf nodes) based on those gradients,
    /// moving them in the direction that minimizes the loss function.
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate to use for updating the parameters.
    ///
    /// # Returns
    ///
    /// Returns nothing. The gradients and results are updated in place.
    ///
    /// Example:
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let expr = Expr::new_leaf(1.0);
    /// let expr2 = expr.tanh();
    /// let mut compiled = CompiledExpr::from_expr(expr2);
    /// compiled.learn(1e-09);
    /// compiled.recalculate();
    /// ```
    pub fn learn(&mut self, learning_rate: f64) {
        // set last gradient to 1.0
        self.gradients[self.results.len() - 1] = 1.0;

        for i in (0..self.results.len()).rev() {
            let operation = self.operations[i];
            let lhs_index = self.lhs[i].unwrap_or(0);
            let rhs_index = self.rhs[i].unwrap_or(0);

            let lhs_result = if let Some(index) = self.lhs[i] {
                self.results[index]
            } else {
                0.0 // Default value for leaf nodes
            };

            let rhs_result = if let Some(index) = self.rhs[i] {
                self.results[index]
            } else {
                0.0 // Default value for leaf nodes
            };
            let result = self.results[i];
            let gradient = self.gradients[i];

            match operation {
                // Learnable leaves
                Operation::None => {
                    // For learnable leaves only, update the result directly
                    // (the gradient is already set by a previous operation)
                    if self.is_learnable[i] {
                        self.results[i] -= learning_rate * self.gradients[i];
                    }
                }
                // Unary operations
                Operation::Tanh => {
                    let tanh_grad = 1.0 - (result * result);
                    self.gradients[lhs_index] = gradient * tanh_grad;
                }
                Operation::Exp => {
                    self.gradients[lhs_index] = gradient * result;
                }
                Operation::ReLU => {
                    // d/dx relu(x) is 1 for x > 0 and 0 otherwise, chained with the
                    // incoming gradient.
                    self.gradients[lhs_index] = if result > 0.0 {
                        gradient
                    } else {
                        0.0
                    };
                }
                Operation::Log => {
                    // d/dx ln(x) = 1 / x, so divide by the operand, not the result.
                    self.gradients[lhs_index] = gradient / lhs_result;
                }
                Operation::Neg => {
                    self.gradients[lhs_index] = -gradient;
                }
                // Binary operations
                Operation::Add => {
                    self.gradients[lhs_index] = gradient;
                    self.gradients[rhs_index] = gradient;
                }
                Operation::Sub => {
                    self.gradients[lhs_index] = gradient;
                    self.gradients[rhs_index] = -gradient;
                }
                Operation::Mul => {
                    self.gradients[lhs_index] = gradient * rhs_result;
                    self.gradients[rhs_index] = gradient * lhs_result;
                }
                Operation::Div => {
                    self.gradients[lhs_index] = gradient / rhs_result;
                    self.gradients[rhs_index] = -gradient * lhs_result / (rhs_result * rhs_result);
                }
                Operation::Pow => {
                    let exponent = rhs_result;
                    let base = lhs_result;

                    self.gradients[lhs_index] = gradient * exponent * base.powf(exponent - 1.0);
                    self.gradients[rhs_index] = gradient * lhs_result.ln() * result;
                }
            }
        }
    }

    /// Returns the final result of the compiled expression.
    ///
    /// This function returns the last result in the results vector, which corresponds
    /// to the final output of the expression.
    ///
    /// # Returns
    ///
    /// Returns the final result as a `f64` value.
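    ///
    /// Example (an illustrative sketch, reusing the `tanh(1.0)` value from the examples above):
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let expr = Expr::new_leaf(1.0);
    /// let compiled = CompiledExpr::from_expr(expr.tanh());
    /// assert_eq!(compiled.result(), 0.7615941559557649);
    /// ```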
    pub fn result(&self) -> f64 {
        if self.results.is_empty() {
            0.0
        } else {
            *self.results.last().unwrap()
        }
    }

    /// Gets the gradient of a learnable parameter by its name.
    ///
    /// This function retrieves the gradient of a learnable parameter (e.g., a weight or bias)
    /// by looking up its name in the names-to-index mapping.
    ///
    /// # Arguments
    ///
    /// * `name` - The name of the parameter to get the gradient for.
    ///
    /// # Returns
    ///
    /// Returns an `Option<f64>` containing the gradient value if found, or `None` if not found.
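    ///
    /// Example (a minimal sketch; the gradient values follow the addition rule exercised
    /// in the tests below):
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let a = Expr::new_leaf_with_name(1.0, "a");
    /// let b = Expr::new_leaf_with_name(2.0, "b");
    /// let mut compiled = CompiledExpr::from_expr(a + b);
    ///
    /// compiled.learn(1e-09);
    ///
    /// // Addition passes the output gradient (seeded to 1.0) unchanged to both operands.
    /// assert_eq!(compiled.get_grad_by_name("a"), Some(1.0));
    /// assert_eq!(compiled.get_grad_by_name("b"), Some(1.0));
    /// assert_eq!(compiled.get_grad_by_name("missing"), None);
    /// ```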
    pub fn get_grad_by_name(&self, name: &str) -> Option<f64> {
        if let Some(&index) = self.names_to_index.get(name) {
            return Some(self.gradients[index]);
        }
        None
    }

    /// Sets the value of a parameter by its name.
    ///
    /// This method sets the value of a parameter in the compiled expression.
    /// It is used to modify the values of leaf nodes in the expression tree.
    ///
    /// Example:
    ///
    /// ```rust
    /// use alpha_micrograd_rust::value::Expr;
    /// use alpha_micrograd_rust::compiled::CompiledExpr;
    ///
    /// let expr = Expr::new_leaf_with_name(1.0, "x");
    /// let expr2 = expr.tanh();
    /// let mut compiled = CompiledExpr::from_expr(expr2);
    ///
    /// compiled.set("x", 2.0);
    /// compiled.recalculate();
    ///
    /// assert_eq!(compiled.result(), 0.9640275800758169);
    /// ```
    pub fn set(&mut self, name: &str, value: f64) {
        if let Some(&index) = self.names_to_index.get(name) {
            self.results[index] = value;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_float_eq(f1: f64, f2: f64) {
        let places = 7;
        let tolerance = 10.0_f64.powi(-places);
        assert!((f1 - f2).abs() < tolerance, "{} != {} (tol: {})", f1, f2, tolerance);
    }

    #[test]
    fn test_from_expr_multilevel() {
        // Create a multilevel expression: (a + b) * (c - d)
        // where a=2.0, b=3.0, c=5.0, d=1.0
        // This should result in (2.0 + 3.0) * (5.0 - 1.0) = 5.0 * 4.0 = 20.0

        // Create leaf nodes
        let a = Expr::new_leaf(2.0);
        let b = Expr::new_leaf_with_name(3.0, "b");
        let c = Expr::new_leaf(5.0);
        let d = Expr::new_leaf_with_name(1.0, "d");

        // Create (a + b)
        let add = a + b;

        // Create (c - d)
        let sub = c - d;

        // Create (a + b) * (c - d)
        let mul = add * sub;

        // Convert to tape
        let tape = CompiledExpr::from_expr(mul);

        // Verify that all elements of the tape have the same length
        assert_eq!(tape.results.len(), 7);
        assert_eq!(tape.operations.len(), 7);
        assert_eq!(tape.lhs.len(), 7);
        assert_eq!(tape.rhs.len(), 7);
        assert_eq!(tape.gradients.len(), 7);

        // Verify each operation in the tape
        // a: leaf, 2.0
        assert_eq!(tape.results[0], 2.0);
        assert_eq!(tape.lhs[0], None);
        assert_eq!(tape.rhs[0], None); // Leaf node
        assert_eq!(tape.operations[0], Operation::None);
        assert_eq!(tape.gradients[0], 0.0); // Default gradient for leaf

        // b: leaf, 3.0
        assert_eq!(tape.results[1], 3.0);
        assert_eq!(tape.lhs[1], None);
        assert_eq!(tape.rhs[1], None); // Leaf node
        assert_eq!(tape.operations[1], Operation::None);
        assert_eq!(tape.gradients[1], 0.0); // Default gradient for leaf

        // add: (a + b)
        assert_eq!(tape.results[2], 5.0);
        assert_eq!(tape.lhs[2], Some(0)); // Index of a
        assert_eq!(tape.rhs[2], Some(1)); // Index of b
        assert_eq!(tape.operations[2], Operation::Add);
        assert_eq!(tape.gradients[2], 0.0); // Default gradient for result

        // c: leaf, 5.0
        assert_eq!(tape.results[3], 5.0);
        assert_eq!(tape.lhs[3], None);
        assert_eq!(tape.rhs[3], None); // Leaf node
        assert_eq!(tape.operations[3], Operation::None);
        assert_eq!(tape.gradients[3], 0.0); // Default gradient for leaf

        // d: leaf, 1.0
        assert_eq!(tape.results[4], 1.0);
        assert_eq!(tape.lhs[4], None);
        assert_eq!(tape.rhs[4], None); // Leaf node
        assert_eq!(tape.operations[4], Operation::None);
        assert_eq!(tape.gradients[4], 0.0); // Default gradient for leaf

        // sub: (c - d)
        assert_eq!(tape.results[5], 4.0);
        assert_eq!(tape.lhs[5], Some(3)); // Index of c
        assert_eq!(tape.rhs[5], Some(4)); // Index of d
        assert_eq!(tape.operations[5], Operation::Sub);
        assert_eq!(tape.gradients[5], 0.0); // Default gradient for result

        // mul: (a + b) * (c - d)
        assert_eq!(tape.results[6], 20.0);
        assert_eq!(tape.lhs[6], Some(2)); // Index of add
        assert_eq!(tape.rhs[6], Some(5)); // Index of sub
        assert_eq!(tape.operations[6], Operation::Mul);
        assert_eq!(tape.gradients[6], 0.0); // Default gradient for result

        // Verify names to index mapping
        assert_eq!(tape.names_to_index.get("b"), Some(&1));
        assert_eq!(tape.names_to_index.get("d"), Some(&4));
        assert!(tape.names_to_index.get("a").is_none());
        assert!(tape.names_to_index.get("c").is_none());
    }

    #[test]
    fn test_recalculate() {
        // Create a simple expression: a + b
        let a = Expr::new_leaf(2.0);
        let b = Expr::new_leaf(3.0);
        let expr = a + b;

        // Convert to tape
        let mut tape = CompiledExpr::from_expr(expr);

        // Recalculate the results
        tape.recalculate();

        // Verify the result
        assert_eq!(tape.results[2], 5.0); // Result of a + b

        tape.results[0] = 4.0; // Change a to 4.0
        tape.results[1] = 6.0; // Change b to 6.0
        tape.recalculate();

        // Verify the recalculated result
        assert_eq!(tape.results[2], 10.0); // Result of 4.0 + 6.0
    }

    #[test]
    fn test_learn_simple() {
        let expr = Expr::new_leaf(1.0);
        let mut tape = CompiledExpr::from_expr(expr);
        assert_eq!(tape.result(), 1.0);

        tape.learn(1e-01);
        assert_eq!(tape.result(), 0.9); // 1.0 - 0.1 = 0.9
    }

    #[test]
    fn test_learn_skips_non_learnable() {
        let mut expr = Expr::new_leaf(1.0);
        expr.is_learnable = false;
        let mut tape = CompiledExpr::from_expr(expr);
        assert_eq!(tape.result(), 1.0);

        tape.learn(1e-01);
        assert_eq!(tape.result(), 1.0);
    }

    #[test]
    fn test_learn_multilevel() {
        let expr = Expr::new_leaf(1.0);
        let expr2 = expr.tanh();
        let mut tape = CompiledExpr::from_expr(expr2);
        assert_eq!(tape.result(), 0.7615941559557649); // tanh(1.0)
        tape.learn(1e-09);
        tape.recalculate();

        assert_eq!(tape.result(), 0.7615941557793864);
    }

    #[test]
    fn test_backpropagation_add() {
        let mut operand1 = Expr::new_leaf(1.0);
        operand1.name = Some("a".to_string());

        let mut operand2 = Expr::new_leaf(2.0);
        operand2.name = Some("b".to_string());

        let expr3 = operand1 + operand2;
        let mut tape = CompiledExpr::from_expr(expr3);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        let grad_b = tape.get_grad_by_name("b").unwrap();
        assert_eq!(grad_a, 1.0);
        assert_eq!(grad_b, 1.0);
    }

    #[test]
    fn test_backpropagation_sub() {
        let mut operand1 = Expr::new_leaf(1.0);
        operand1.name = Some("a".to_string());

        let mut operand2 = Expr::new_leaf(2.0);
        operand2.name = Some("b".to_string());

        let expr3 = operand1 - operand2;
        let mut tape = CompiledExpr::from_expr(expr3);
        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        let grad_b = tape.get_grad_by_name("b").unwrap();
        assert_eq!(grad_a, 1.0);
        assert_eq!(grad_b, -1.0);
    }

    #[test]
    fn test_backpropagation_mul() {
        let mut operand1 = Expr::new_leaf(3.0);
        operand1.name = Some("a".to_string());

        let mut operand2 = Expr::new_leaf(4.0);
        operand2.name = Some("b".to_string());

        let expr3 = operand1 * operand2;
        let mut tape = CompiledExpr::from_expr(expr3);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        let grad_b = tape.get_grad_by_name("b").unwrap();
        assert_eq!(grad_a, 4.0);
        assert_eq!(grad_b, 3.0);
    }

    #[test]
    fn test_backpropagation_div() {
        let mut operand1 = Expr::new_leaf(3.0);
        operand1.name = Some("a".to_string());

        let mut operand2 = Expr::new_leaf(4.0);
        operand2.name = Some("b".to_string());
        let expr3 = operand1 / operand2;
        let mut tape = CompiledExpr::from_expr(expr3);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        let grad_b = tape.get_grad_by_name("b").unwrap();
        assert_eq!(grad_a, 0.25);
        assert_eq!(grad_b, -0.1875);
    }

    #[test]
    fn test_backpropagation_tanh() {
        let mut operand1 = Expr::new_leaf(0.0);
        operand1.name = Some("a".to_string());
        let expr2 = operand1.tanh();
        let mut tape = CompiledExpr::from_expr(expr2);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        assert_float_eq(grad_a, 1.0);
    }

    #[test]
    fn test_backpropagation_relu() {
        let mut operand1 = Expr::new_leaf(-1.0);
        operand1.name = Some("a".to_string());
        let expr2 = operand1.relu();
        let mut tape = CompiledExpr::from_expr(expr2);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        assert_eq!(grad_a, 0.0);
    }

    #[test]
    fn test_backpropagation_exp() {
        let mut operand1 = Expr::new_leaf(0.0);
        operand1.name = Some("a".to_string());
        let expr2 = operand1.exp();
        let mut tape = CompiledExpr::from_expr(expr2);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        assert_eq!(grad_a, 1.0);
    }

    #[test]
    fn test_backpropagation_pow() {
        let mut operand1 = Expr::new_leaf(2.0);
        operand1.name = Some("a".to_string());
        let mut operand2 = Expr::new_leaf(3.0);
        operand2.name = Some("b".to_string());
        let expr3 = operand1.pow(operand2);
        let mut tape = CompiledExpr::from_expr(expr3);

        tape.learn(1e-09);

        let grad_a = tape.get_grad_by_name("a").unwrap();
        let grad_b = tape.get_grad_by_name("b").unwrap();
        assert_eq!(grad_a, 12.0);
        assert_eq!(grad_b, 5.545177444479562);
    }

    #[test]
    fn test_backpropagation_mixed_tree() {
        let mut operand1 = Expr::new_leaf(1.0);
        operand1.name = Some("operand1".to_string());
        let mut operand2 = Expr::new_leaf(2.0);
        operand2.name = Some("operand2".to_string());
        let mut expr3 = operand1 + operand2;
        expr3.name = Some("expr3".to_string());
        let expr4 = expr3.tanh();
        let mut tape = CompiledExpr::from_expr(expr4);

        tape.learn(1e-09);

        let expr3_grad = tape.get_grad_by_name("expr3").unwrap();
        let operand1_grad = tape.get_grad_by_name("operand1").unwrap();
        let operand2_grad = tape.get_grad_by_name("operand2").unwrap();

        assert_eq!(expr3_grad, 0.009866037165440211);
        assert_eq!(operand1_grad, 0.009866037165440211);
        assert_eq!(operand2_grad, 0.009866037165440211);
    }

    #[test]
    fn test_backpropagation_karpathys_example() {
        let mut x1 = Expr::new_leaf(2.0);
        x1.name = Some("x1".to_string());
        let mut x2 = Expr::new_leaf(0.0);
        x2.name = Some("x2".to_string());
        let mut w1 = Expr::new_leaf(-3.0);
        w1.name = Some("w1".to_string());
        let mut w2 = Expr::new_leaf(1.0);
        w2.name = Some("w2".to_string());
        let mut b = Expr::new_leaf(6.8813735870195432);
        b.name = Some("b".to_string());

        let mut x1w1 = x1 * w1;
        x1w1.name = Some("x1w1".to_string());
        let mut x2w2 = x2 * w2;
        x2w2.name = Some("x2w2".to_string());
        let mut x1w1_x2w2 = x1w1 + x2w2;
        x1w1_x2w2.name = Some("x1w1_x2w2".to_string());
        let mut n = x1w1_x2w2 + b;
        n.name = Some("n".to_string());
        let o = n.tanh();
        let mut tape = CompiledExpr::from_expr(o);

        tape.learn(1e-09);

        let n_grad = tape.get_grad_by_name("n").unwrap();
        assert_float_eq(n_grad, 0.5);

        let x1w1_x2w2_grad = tape.get_grad_by_name("x1w1_x2w2").unwrap();
        assert_float_eq(x1w1_x2w2_grad, 0.5);

        let b_grad = tape.get_grad_by_name("b").unwrap();
        assert_float_eq(b_grad, 0.5);

        let x1w1_grad = tape.get_grad_by_name("x1w1").unwrap();
        assert_float_eq(x1w1_grad, 0.5);

        let x2w2_grad = tape.get_grad_by_name("x2w2").unwrap();
        assert_float_eq(x2w2_grad, 0.5);

        let x1_grad = tape.get_grad_by_name("x1").unwrap();
        assert_float_eq(x1_grad, -1.5);

        let w1_grad = tape.get_grad_by_name("w1").unwrap();
        assert_float_eq(w1_grad, 1.0);

        let x2_grad = tape.get_grad_by_name("x2").unwrap();
        assert_float_eq(x2_grad, 0.5);

        let w2_grad = tape.get_grad_by_name("w2").unwrap();
        assert_float_eq(w2_grad, 0.0);
    }
}