/// An operator that can be stored in the computation graph.
///
/// Implement this trait for an operator type so the graph can hold it
/// (the method set is object-safe, so it can live behind `dyn OpTrait`).
pub trait OpTrait {
    // ---- identity & arity -------------------------------------------------

    /// A static name for this op.
    /// NOTE(review): presumably used for display/debugging — confirm.
    fn get_name(&self) -> &'static str;

    /// The number of inputs this op requires.
    fn get_input_size(&self) -> usize;

    /// The number of outputs this op produces.
    fn get_output_size(&self) -> usize;

    // ---- forward / backward -----------------------------------------------

    /// Forward pass over `input`, producing results for `output`.
    ///
    /// NOTE(review): the method returns `()`, so results are presumably
    /// written through the `output` tensors (interior mutability) — confirm.
    fn apply(&self, input: &[Tensor], output: &[Tensor]);

    /// Backward pass.
    ///
    /// Given the forward `input` values and the backward `output_grad`,
    /// update the weight gradients and produce the backward input gradient.
    /// Because the method returns `()`, the input gradient is presumably
    /// delivered through `input_grad` — confirm against implementors.
    fn grad(
        &self,
        input: &[Tensor],
        output_grad: &[Tensor],
        input_grad: &[Tensor],
    );

    // ---- parameters -------------------------------------------------------

    /// Access the op's weight values.
    fn get_values(&self) -> Vec<Tensor>;

    /// Replace the op's weight values with `v`.
    ///
    /// Takes `&self`, so implementations need interior mutability to store `v`.
    fn set_values(&self, v: &[Tensor]);

    /// Access gradients held by the op.
    /// NOTE(review): presumably the weight gradients updated by `grad` — confirm.
    fn get_grads(&self) -> Vec<Tensor>;

    // ---- downcasting ------------------------------------------------------

    /// Expose the op as `&dyn Any`, so callers can `downcast_ref` to a
    /// concrete op type.
    fn as_any(&self) -> &dyn Any;
}
Implement this trait for an operator so that it can be stored in the computation graph.
Required Methods
fn get_input_size(&self) -> usize
The number of inputs required by this op.
fn get_output_size(&self) -> usize
The number of outputs produced by this op.
fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor])
Given the forward input values and the backward output_grad, update the weight gradients and produce the backward input gradient (the method returns nothing, so this is delivered through input_grad).
fn get_values(&self) -> Vec<Tensor>
Access the weight values.