pub struct CompiledExpr {
    pub results: Vec<f64>,
    pub gradients: Vec<f64>,
    /* private fields */
}
A compiled version of an expression tree.
This struct stores the expression's operations in flat arrays, so computing results and backpropagating gradients iterate over an array instead of traversing a tree. This makes both operations faster than on the regular expression tree.
Fields
results: Vec<f64>
The result of each operation in the expression; the last entry is the final output.
gradients: Vec<f64>
The gradient of each operation in the expression.
Implementations
impl CompiledExpr
pub fn from_expr(expr: Expr) -> Self
Creates a new CompiledExpr from an expression.
This method consumes the expression and transforms it into a compiled form that is more efficient for computation and backpropagation.
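As an illustration of what "compiling" can mean here, the sketch below flattens a small tree into an array with a post-order traversal: children are always emitted before their parent, so every operand index points at an earlier entry and the root ends up last. The Tree and Op enums and the compile function are hypothetical stand-ins, not the crate's private fields. The sketch also fills in each node's value while compiling, an assumption consistent with result() being usable right after from_expr in the recalculate example.

```rust
// Hypothetical types for illustration; not alpha_micrograd_rust's internals.

/// Tree form: each node owns its children.
enum Tree {
    Leaf(f64),
    Add(Box<Tree>, Box<Tree>),
    Tanh(Box<Tree>),
}

/// Flat form: operands are indices into the compiled arrays.
#[derive(Debug, PartialEq)]
enum Op {
    Leaf,
    Add(usize, usize),
    Tanh(usize),
}

/// Post-order flattening: compile the children first, then push the parent.
/// Returns the index of the node just pushed; for the root this is the last index.
fn compile(tree: Tree, ops: &mut Vec<Op>, results: &mut Vec<f64>) -> usize {
    let (op, value) = match tree {
        Tree::Leaf(v) => (Op::Leaf, v),
        Tree::Add(a, b) => {
            let ia = compile(*a, ops, results);
            let ib = compile(*b, ops, results);
            (Op::Add(ia, ib), results[ia] + results[ib])
        }
        Tree::Tanh(a) => {
            let ia = compile(*a, ops, results);
            (Op::Tanh(ia), results[ia].tanh())
        }
    };
    ops.push(op);
    results.push(value);
    ops.len() - 1
}
```

Taking the tree by value mirrors from_expr consuming its Expr: once flattened, the tree itself is no longer needed.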
Example:
use alpha_micrograd_rust::value::Expr;
use alpha_micrograd_rust::compiled::CompiledExpr;
let expr = Expr::new_leaf(1.0);
let expr2 = expr.tanh();
let compiled = CompiledExpr::from_expr(expr2);
pub fn recalculate(&mut self)
Recalculates the expression based on the current values.
This method recomputes every entry of results from the current values of the parameters. It is more efficient than re-evaluating the expression tree, because it makes a single pass through an array of operations instead of traversing a tree structure.
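A minimal sketch of the difference, assuming hypothetical Tree and Op types rather than the crate's real ones: the tree version recurses through boxed children, while the flat version is one loop in which each operand index refers to an earlier entry of results. Changing a leaf and rerunning the loop is the same pattern as set() followed by recalculate().

```rust
// Hypothetical types contrasting the two evaluation strategies; illustrative
// only, not alpha_micrograd_rust's private representation.

/// Tree form: evaluation recurses through boxed children.
enum Tree {
    Leaf(f64),
    Mul(Box<Tree>, Box<Tree>),
    Tanh(Box<Tree>),
}

fn eval_tree(t: &Tree) -> f64 {
    match t {
        Tree::Leaf(v) => *v,
        Tree::Mul(a, b) => eval_tree(a) * eval_tree(b),
        Tree::Tanh(a) => eval_tree(a).tanh(),
    }
}

/// Flat form: operands are indices of earlier entries in `results`.
enum Op {
    Leaf,
    Mul(usize, usize),
    Tanh(usize),
}

/// One forward pass in array order. Leaf entries keep their stored value;
/// every other entry is recomputed from operands that precede it.
fn recalculate(ops: &[Op], results: &mut [f64]) {
    for i in 0..ops.len() {
        results[i] = match ops[i] {
            Op::Leaf => results[i],
            Op::Mul(a, b) => results[a] * results[b],
            Op::Tanh(a) => results[a].tanh(),
        };
    }
}
```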
Example:
use alpha_micrograd_rust::value::Expr;
use alpha_micrograd_rust::compiled::CompiledExpr;
let expr = Expr::new_leaf_with_name(1.0, "x");
let expr2 = expr.tanh();
let mut compiled = CompiledExpr::from_expr(expr2);
assert_eq!(compiled.result(), 0.7615941559557649);
// Modify the value of "x"
compiled.set("x", 2.0);
compiled.recalculate();
assert_eq!(compiled.result(), 0.9640275800758169);
pub fn learn(&mut self, learning_rate: f64)
Performs one step of learning (backpropagation) on the compiled expression.
This function updates the values of the learnable parameters in the expression based on the gradients calculated during backpropagation.
Internally, it backpropagates from the last node of the computation graph to update the gradients, then adjusts the values of the learnable parameters (individual expression nodes) using those gradients so as to minimize the loss function.
Arguments
learning_rate - The learning rate used when updating the parameters.
Returns
Nothing. The gradients and parameter values are updated in place.
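The two phases can be sketched on a hypothetical flat tape as a reverse pass (chain rule, walking the operations backwards) followed by a plain gradient-descent update, value -= learning_rate * gradient. The Tape struct, its learnable flags, and the choice of plain gradient descent are assumptions for illustration; the crate's actual update rule may differ.

```rust
// Hypothetical flat tape; names and the `learnable` flags are assumptions,
// not alpha_micrograd_rust's API.
enum Op {
    Leaf,
    Mul(usize, usize),
    Tanh(usize),
}

struct Tape {
    ops: Vec<Op>,
    results: Vec<f64>,
    gradients: Vec<f64>,
    learnable: Vec<bool>,
}

impl Tape {
    /// Reverse pass: seed d(out)/d(out) = 1 on the last node, then walk the
    /// ops backwards, adding each node's contribution to its operands.
    fn backward(&mut self) {
        self.gradients.iter_mut().for_each(|g| *g = 0.0);
        let last = self.ops.len() - 1;
        self.gradients[last] = 1.0;
        for i in (0..self.ops.len()).rev() {
            let g = self.gradients[i];
            match self.ops[i] {
                Op::Leaf => {}
                Op::Mul(a, b) => {
                    self.gradients[a] += g * self.results[b];
                    self.gradients[b] += g * self.results[a];
                }
                Op::Tanh(a) => {
                    // d/dx tanh(x) = 1 - tanh(x)^2, reusing the stored result.
                    let t = self.results[i];
                    self.gradients[a] += g * (1.0 - t * t);
                }
            }
        }
    }

    /// One learning step: backpropagate, then move each learnable entry
    /// against its gradient (plain gradient descent).
    fn learn(&mut self, learning_rate: f64) {
        self.backward();
        for i in 0..self.ops.len() {
            if self.learnable[i] {
                self.results[i] -= learning_rate * self.gradients[i];
            }
        }
    }
}
```

In this sketch, learn changes only the learnable leaves, so the downstream results stay stale until a forward pass runs again; that matches the example below calling recalculate() after learn().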
Example:
use alpha_micrograd_rust::value::Expr;
use alpha_micrograd_rust::compiled::CompiledExpr;
let expr = Expr::new_leaf(1.0);
let expr2 = expr.tanh();
let mut compiled = CompiledExpr::from_expr(expr2);
compiled.learn(1e-09);
compiled.recalculate();
pub fn result(&self) -> f64
Returns the final result of the compiled expression.
This function returns the last result in the results vector, which corresponds to the final output of the expression.
Returns
The final result as an f64 value.
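Because the output node is compiled last, reading the result amounts to taking the final entry of results. A sketch with a hypothetical Compiled struct that holds only that field; panicking on an empty vector is an assumption of the sketch.

```rust
// Hypothetical stand-in holding only the public `results` field.
struct Compiled {
    results: Vec<f64>,
}

impl Compiled {
    /// The output node sits at the end of the array, so its value is the last entry.
    fn result(&self) -> f64 {
        *self.results.last().expect("compiled expression has no nodes")
    }
}
```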
pub fn get_grad_by_name(&self, name: &str) -> Option<f64>
Gets the gradient of a learnable parameter by its name.
This function retrieves the gradient of a learnable parameter (e.g., a weight or bias) by looking up its name in the names-to-index mapping.
Arguments
name - The name of the parameter to get the gradient for.
Returns
An Option<f64> containing the gradient value if the name is found, or None otherwise.
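A sketch of the lookup, assuming the names-to-index mapping is a HashMap<String, usize> whose values index into the gradients vector. The struct and field names are hypothetical.

```rust
use std::collections::HashMap;

// Hypothetical struct; the real mapping is a private field of CompiledExpr.
struct Compiled {
    gradients: Vec<f64>,
    names_to_index: HashMap<String, usize>,
}

impl Compiled {
    /// Look up the name, then read the gradient stored at that index.
    /// An unknown name yields None.
    fn get_grad_by_name(&self, name: &str) -> Option<f64> {
        self.names_to_index.get(name).map(|&i| self.gradients[i])
    }
}
```

HashMap<String, usize>::get accepts a &str key because String implements Borrow<str>, so no allocation is needed per lookup.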
pub fn set(&mut self, name: &str, value: f64)
Sets the value of a parameter by its name.
This method overwrites the value of a named leaf node in the compiled expression, for example to feed in a new input. Call recalculate() afterwards so that the results reflect the new value.
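A sketch of the same idea with hypothetical names: set only overwrites the stored value of the named leaf, and entries that depend on it stay stale until a forward pass runs. This sketch silently ignores unknown names; that is an assumption made for illustration, and the crate may handle a missing name differently.

```rust
use std::collections::HashMap;

// Hypothetical struct; field names are illustrative, not the crate's.
struct Compiled {
    results: Vec<f64>,
    names_to_index: HashMap<String, usize>,
}

impl Compiled {
    /// Overwrite the named leaf's stored value; downstream entries are untouched.
    fn set(&mut self, name: &str, value: f64) {
        if let Some(&i) = self.names_to_index.get(name) {
            self.results[i] = value;
        }
    }
}
```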
Example:
use alpha_micrograd_rust::value::Expr;
use alpha_micrograd_rust::compiled::CompiledExpr;
let expr = Expr::new_leaf_with_name(1.0, "x");
let expr2 = expr.tanh();
let mut compiled = CompiledExpr::from_expr(expr2);
compiled.set("x", 2.0);
compiled.recalculate();
assert_eq!(compiled.result(), 0.9640275800758169);