/// An operator that can live as a node in the computation graph.
///
/// Implement this trait for an operator so that it can be stored in the
/// computation graph.
pub trait OpTrait {
    /// The name of this op.
    fn get_name(&self) -> &'static str;

    /// The number of inputs needed by this op.
    fn get_input_size(&self) -> usize;

    /// The number of outputs produced by this op.
    fn get_output_size(&self) -> usize;

    /// Forward computation of the op.
    ///
    /// Both slices are shared borrows. Results are presumably written into
    /// `output` through `Tensor`'s own interior mutability — TODO confirm.
    fn apply(&self, input: &[Tensor], output: &[Tensor]);

    /// Backward computation of the op.
    ///
    /// Given the forward `input` values and the backward `output_grad`,
    /// updates the weight gradients and produces the backward input
    /// gradients. This method returns `()`, so the input gradients are
    /// presumably delivered through `input_grad` — confirm with implementors.
    fn grad(
        &self,
        input: &[Tensor],
        output_grad: &[Tensor],
        input_grad: &[Tensor],
    );

    /// Access the weight values.
    fn get_values(&self) -> Vec<Tensor>;

    /// Set the weight values (presumably the counterpart of `get_values`).
    ///
    /// Takes `&self`, so implementors must mutate through interior
    /// mutability.
    fn set_values(&self, v: &[Tensor]);

    /// Access the weight gradients (presumably those updated by `grad`).
    fn get_grads(&self) -> Vec<Tensor>;

    /// Upcast to `&dyn Any`, allowing a trait object to be downcast back to
    /// its concrete op type via `Any::downcast_ref`.
    fn as_any(&self) -> &dyn Any;
}
Implement this trait for an operator so that the operator can be stored in the computation graph.
Required methods
fn get_input_size(&self) -> usize
The number of inputs needed by this op.
fn get_output_size(&self) -> usize
The number of outputs produced by this op.
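fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor])
Given the forward input values and the backward output_grad, update the weight gradients and compute the backward input gradients. (The method returns `()`, so the input gradients are presumably delivered through `input_grad`.)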
fn get_values(&self) -> Vec<Tensor>
Access the weight values.