Trait alumina::ops::Operation
pub trait Operation: OperationClone {
    fn name(&self) -> &str;
    fn propagate_shape_constraints(&self, nodes: &[Node], shapes: &mut [NodeShape]);
    fn num_params(&self) -> usize;
    fn input_node_IDs(&self) -> Vec<NodeID>;
    fn output_node_IDs(&self) -> Vec<NodeID>;
    fn forward(&mut self, data: &mut [RefCell<NodeData>], params: &[f32]);
    fn backward(
        &mut self,
        data: &mut [RefCell<NodeData>],
        params: &[f32],
        param_deriv: &mut [f32],
        error: &mut f32
    );

    fn init_params(&mut self, params: &mut [f32]) { ... }
}
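Each operation reports how many parameters it owns through num_params, and the backward documentation states that each operation is passed only its relevant slice of params and param_deriv. The sketch below shows one plausible way a graph could carve a single flat f32 buffer into those per-operation slices. It is not alumina's graph code: the NumParams trait, the Bias and Relu structs, and param_slices are hypothetical stand-ins, and laying the slices out consecutively in operation order is an assumption.

```rust
// Hypothetical stand-in for the one `Operation` method used here;
// it is not alumina's trait.
trait NumParams {
    fn num_params(&self) -> usize;
}

// Hypothetical ops: a bias with one parameter per element, and a
// parameter-free activation.
struct Bias {
    size: usize,
}
struct Relu;

impl NumParams for Bias {
    fn num_params(&self) -> usize {
        self.size
    }
}

impl NumParams for Relu {
    fn num_params(&self) -> usize {
        0
    }
}

/// Splits one flat parameter buffer into consecutive per-operation slices,
/// in the order the operations appear in `ops` (an assumed layout).
/// Panics if the buffer is shorter than the total parameter count.
fn param_slices<'a>(ops: &[Box<dyn NumParams>], mut params: &'a [f32]) -> Vec<&'a [f32]> {
    let mut slices = Vec::with_capacity(ops.len());
    for op in ops {
        let (mine, rest) = params.split_at(op.num_params());
        slices.push(mine);
        params = rest;
    }
    slices
}

fn main() {
    let mut ops: Vec<Box<dyn NumParams>> = Vec::new();
    ops.push(Box::new(Bias { size: 2 }));
    ops.push(Box::new(Relu));
    ops.push(Box::new(Bias { size: 3 }));

    // Total buffer size is the sum of every op's parameter count.
    let total: usize = ops.iter().map(|op| op.num_params()).sum();
    let params = vec![0.0f32; total];
    let slices = param_slices(&ops, &params);

    assert_eq!(slices.len(), 3);
    assert_eq!(slices[0].len(), 2);
    assert!(slices[1].is_empty());
    assert_eq!(slices[2].len(), 3);
}
```

Operations with no parameters, such as activations, simply receive an empty slice.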
Required Methods
fn name(&self) -> &str
fn propagate_shape_constraints(&self, nodes: &[Node], shapes: &mut [NodeShape])
fn num_params(&self) -> usize
fn input_node_IDs(&self) -> Vec<NodeID>
fn output_node_IDs(&self) -> Vec<NodeID>
fn forward(&mut self, data: &mut [RefCell<NodeData>], params: &[f32])
Should update the output node values based on the input node values. Must use += when writing to the output node, so as not to overwrite contributions from other operations.
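The += rule matters because several operations can write into the same output node. A minimal sketch, assuming a simplified stand-in for NodeData that holds only a values vector (alumina's real NodeData is not shown on this page) and a hypothetical ScaleOp whose forward method has the same signature as Operation::forward:

```rust
use std::cell::RefCell;

// Hypothetical, simplified stand-in for alumina's `NodeData`:
// only the node's values are modelled here.
struct NodeData {
    values: Vec<f32>,
}

// Hypothetical operation: output += factor * input, element-wise.
struct ScaleOp {
    input: usize,  // index of the input node in `data`
    output: usize, // index of the output node in `data`
    factor: f32,
}

impl ScaleOp {
    // Same shape as `Operation::forward`; this op has no parameters.
    fn forward(&mut self, data: &mut [RefCell<NodeData>], _params: &[f32]) {
        // Borrowing two different RefCells at once is fine; this would
        // panic at runtime if `input == output`.
        let input = data[self.input].borrow();
        let mut output = data[self.output].borrow_mut();
        for (o, &x) in output.values.iter_mut().zip(input.values.iter()) {
            *o += self.factor * x; // accumulate; never overwrite with `=`
        }
    }
}

fn main() {
    let mut data = vec![
        RefCell::new(NodeData { values: vec![1.0, 2.0] }),   // node 0: input a
        RefCell::new(NodeData { values: vec![10.0, 20.0] }), // node 1: input b
        RefCell::new(NodeData { values: vec![0.0, 0.0] }),   // node 2: shared output
    ];
    let mut a = ScaleOp { input: 0, output: 2, factor: 2.0 };
    let mut b = ScaleOp { input: 1, output: 2, factor: 0.5 };

    a.forward(&mut data, &[]);
    b.forward(&mut data, &[]);

    // Both contributions survive: [2*1 + 0.5*10, 2*2 + 0.5*20].
    assert_eq!(data[2].borrow().values, vec![7.0, 14.0]);
}
```

Had either forward used = instead of +=, the second call would have discarded the first operation's contribution to node 2.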
fn backward(
&mut self,
data: &mut [RefCell<NodeData>],
params: &[f32],
param_deriv: &mut [f32],
error: &mut f32
)
Should calculate the operation's contribution to the error gradient with respect to its input nodes and its parameters, based on the output node derivatives. Each operation is passed only its relevant slice of params and param_deriv. Note: all calculations should use += so as not to overwrite other operations' contributions. When the data holds n > 1 examples, the parameter gradients from all individual examples should be summed into param_deriv, and the errors summed into error; the graph later divides by n to obtain the mean error and the mean error derivatives.
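The accumulation and batching rules for backward can be illustrated with two hypothetical operations: a per-feature bias (BiasOp) and a squared-error loss (SqErrLoss). Both mirror the signature of Operation::backward but are inherent methods on simplified stand-in types, not alumina's Bias or MseLoss. The NodeData layout assumed here (n examples of width elements stored back to back, with matching values and derivatives vectors) is likewise an assumption. Each op sums over the whole batch and leaves the division by n to the caller.

```rust
use std::cell::RefCell;

// Hypothetical, simplified stand-in for alumina's `NodeData`: a batch of
// n examples, each `width` elements long, stored back to back.
struct NodeData {
    values: Vec<f32>,
    derivatives: Vec<f32>,
}

// Hypothetical per-feature bias: out[b][j] = in[b][j] + params[j].
struct BiasOp {
    input: usize,
    output: usize,
}

impl BiasOp {
    fn backward(&mut self, data: &mut [RefCell<NodeData>], params: &[f32],
                param_deriv: &mut [f32], _error: &mut f32) {
        let width = params.len(); // one parameter per feature
        let mut input = data[self.input].borrow_mut();
        let output = data[self.output].borrow();
        for (b, out_d) in output.derivatives.chunks(width).enumerate() {
            for j in 0..width {
                // d out / d in = 1: pass the derivative through, accumulating.
                input.derivatives[b * width + j] += out_d[j];
                // d out / d params[j] = 1 in every example: sum over the batch.
                param_deriv[j] += out_d[j];
            }
        }
    }
}

// Hypothetical squared-error loss of node `input` against a fixed target.
struct SqErrLoss {
    input: usize,
    target: Vec<f32>,
}

impl SqErrLoss {
    fn backward(&mut self, data: &mut [RefCell<NodeData>], _params: &[f32],
                _param_deriv: &mut [f32], error: &mut f32) {
        let mut node = data[self.input].borrow_mut();
        let NodeData { values, derivatives } = &mut *node;
        for ((d, &x), &t) in derivatives.iter_mut().zip(values.iter()).zip(&self.target) {
            let diff = x - t;
            *error += diff * diff; // summed over all examples, not averaged
            *d += 2.0 * diff;      // accumulated, not overwritten
        }
    }
}

fn main() {
    let n = 2; // examples in the batch; width = 2
    let params = [0.5f32, -0.5];
    let mut param_deriv = [0.0f32; 2];
    let mut error = 0.0f32;
    let mut data = vec![
        // node 0: input x
        RefCell::new(NodeData { values: vec![1.0, 2.0, 3.0, 4.0], derivatives: vec![0.0; 4] }),
        // node 1: bias output, as a forward pass would have left it (x + params)
        RefCell::new(NodeData { values: vec![1.5, 1.5, 3.5, 3.5], derivatives: vec![0.0; 4] }),
    ];
    let mut bias = BiasOp { input: 0, output: 1 };
    let mut loss = SqErrLoss { input: 1, target: vec![1.0, 1.0, 3.0, 4.0] };

    // Backward runs in reverse order of the forward pass: loss, then bias.
    loss.backward(&mut data, &[], &mut [], &mut error);
    bias.backward(&mut data, &params, &mut param_deriv, &mut error);

    assert_eq!(error, 1.0); // 4 * 0.5^2, summed over both examples
    assert_eq!(param_deriv, [2.0, 0.0]);
    assert_eq!(data[0].borrow().derivatives, vec![1.0, 1.0, 1.0, -1.0]);

    // Dividing by n is left to the caller (the graph, per the backward docs).
    let mean_error = error / n as f32;
    assert_eq!(mean_error, 0.5);
}
```

Keeping the operations free of any 1/n factor means an operation does not need to know the batch size, and the graph can normalise every gradient in one place.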
Provided Methods
fn init_params(&mut self, params: &mut [f32])
Implementors
impl Operation for LeakyReLU
impl Operation for BeLU
impl<F: ActivationFunc + 'static> Operation for GenericActivation<F>
impl Operation for SoftMax
impl Operation for Convolution
impl Operation for LinearInterp
impl Operation for ShapeConstraint
impl Operation for Pooling
impl Operation for Collapse
impl Operation for Expand
impl Operation for LinearMap
impl Operation for Bias
impl Operation for L2Regularisation
impl Operation for Scale
impl Operation for MseLoss
impl Operation for MaeLoss
impl Operation for CrossEntLoss
impl Operation for SoftMaxCrossEntLoss
impl Operation for SoftMaxDampedCrossEntLoss
impl Operation for PredictionLoss
impl Operation for Broadcast
impl Operation for GlobalAvg