pub trait Operation: Debug {
    fn eval(&self, inputs: &[ArrayViewD<'_, f32>]) -> ArrayD<f32>;
    fn grad(
        &self,
        inputs: &[ArrayViewD<'_, f32>],
        loss: ArrayViewD<'_, f32>,
    ) -> Vec<ArrayD<f32>>;
}
Represents a differentiable function in a computation graph.
Operations hold their own hyperparameters but not their parameters, values or losses.
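As a minimal sketch of what "hyperparameters but not parameters" can look like, the operation below stores only a fixed scale factor; anything trainable would instead arrive through inputs. To stay dependency-free, it mirrors the trait with plain f32 slices in place of ndarray's ArrayViewD/ArrayD, and the names Op, ScaleOp and factor are hypothetical, not part of this crate.

```rust
use std::fmt::Debug;

/// Hypothetical, dependency-free mirror of `Operation`:
/// flat `f32` slices stand in for ndarray's `ArrayViewD`/`ArrayD`.
trait Op: Debug {
    fn eval(&self, inputs: &[&[f32]]) -> Vec<f32>;
    fn grad(&self, inputs: &[&[f32]], loss: &[f32]) -> Vec<Vec<f32>>;
}

/// Multiplies its single input by a fixed factor.
/// `factor` is a hyperparameter: set at construction, never trained.
#[derive(Debug)]
struct ScaleOp {
    factor: f32,
}

impl Op for ScaleOp {
    fn eval(&self, inputs: &[&[f32]]) -> Vec<f32> {
        inputs[0].iter().map(|x| x * self.factor).collect()
    }

    fn grad(&self, _inputs: &[&[f32]], loss: &[f32]) -> Vec<Vec<f32>> {
        // d(factor * x)/dx = factor, so the upstream loss is scaled by it.
        // One gradient per input (here, exactly one).
        vec![loss.iter().map(|g| g * self.factor).collect()]
    }
}

fn main() {
    let op: Box<dyn Op> = Box::new(ScaleOp { factor: 2.0 });
    let x: &[f32] = &[1.0, -3.0, 0.5];
    println!("{:?}", op.eval(&[x])); // [2.0, -6.0, 1.0]
    println!("{:?}", op.grad(&[x], &[1.0_f32; 3])); // [[2.0, 2.0, 2.0]]
}
```

Keeping the factor on the struct, rather than in the graph, is what makes it a hyperparameter in this sketch: the backward pass returns no gradient for it.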
Unfortunately, boxed trait objects cannot be serialized with serde. When a model is reloaded, they are replaced by Box<arithmetic::Add> nodes, so any custom Operations in the model must be replaced manually after reloading.
Required Methods
fn eval(&self, inputs: &[ArrayViewD<'_, f32>]) -> ArrayD<f32>
Computes the output from the inputs and returns it as a new array.
TODO: consider modifying an existing output ArrayD instead of returning a new one.
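The sketch below contrasts the current returning style of eval with one possible reading of the TODO, where the caller supplies an output buffer that is overwritten in place. It uses an elementwise product of two inputs as the example operation, and plain f32 slices stand in for ndarray views; mul_eval and mul_eval_into are hypothetical names.

```rust
/// Returning style (what `eval` does today): allocate and return a new output.
/// Computes the elementwise product of two equal-length inputs.
fn mul_eval(inputs: &[&[f32]]) -> Vec<f32> {
    let (a, b) = (inputs[0], inputs[1]);
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    a.iter().zip(b).map(|(x, y)| x * y).collect()
}

/// In-place style (one reading of the TODO): write into a caller-owned
/// buffer so it can be reused across forward passes without reallocating.
fn mul_eval_into(inputs: &[&[f32]], out: &mut [f32]) {
    let (a, b) = (inputs[0], inputs[1]);
    assert!(a.len() == b.len() && b.len() == out.len(), "length mismatch");
    for ((o, x), y) in out.iter_mut().zip(a).zip(b) {
        *o = x * y;
    }
}

fn main() {
    let a: &[f32] = &[1.0, 2.0, 3.0];
    let b: &[f32] = &[4.0, 5.0, 6.0];
    let fresh = mul_eval(&[a, b]);
    let mut buf = vec![0.0_f32; 3];
    mul_eval_into(&[a, b], &mut buf);
    assert_eq!(fresh, buf);
    println!("{:?}", fresh); // [4.0, 10.0, 18.0]
}
```

The in-place variant trades a simpler signature for fewer allocations; both produce identical values.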
fn grad(
    &self,
    inputs: &[ArrayViewD<'_, f32>],
    loss: ArrayViewD<'_, f32>,
) -> Vec<ArrayD<f32>>
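Returns the gradients with respect to each input, given loss, the gradient with respect to this operation's output.
Note that the returned vector should have the same length as inputs: one gradient per input.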
TODO: consider modifying existing output ArrayDs instead of returning new ones.
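As an illustration of the chain rule behind grad, here is a minimal sketch of the backward pass for an elementwise product out = a * b, written as a free function over plain f32 slices rather than ndarray types. The name mul_grad is hypothetical; the point is that it returns exactly one gradient per input, in input order.

```rust
/// Backward pass for the elementwise product `out = a * b`.
/// `loss` is the gradient of the final loss with respect to `out`;
/// the result holds one gradient per input, in input order.
fn mul_grad(inputs: &[&[f32]], loss: &[f32]) -> Vec<Vec<f32>> {
    let (a, b) = (inputs[0], inputs[1]);
    // Chain rule: d(out)/da = b and d(out)/db = a, elementwise.
    let grad_a: Vec<f32> = loss.iter().zip(b).map(|(g, y)| g * y).collect();
    let grad_b: Vec<f32> = loss.iter().zip(a).map(|(g, x)| g * x).collect();
    let grads = vec![grad_a, grad_b];
    // The documented invariant: one gradient per input.
    debug_assert_eq!(grads.len(), inputs.len());
    grads
}

fn main() {
    let a: &[f32] = &[1.0, 2.0, 3.0];
    let b: &[f32] = &[4.0, 5.0, 6.0];
    let loss: &[f32] = &[1.0, 0.5, -1.0];
    let grads = mul_grad(&[a, b], loss);
    println!("{:?}", grads); // [[4.0, 2.5, -6.0], [1.0, 1.0, -3.0]]
}
```

Each input's gradient has the same shape as that input, which is what lets the graph route it back to the node that produced it.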